diff --git a/.gitignore b/.gitignore index b42a36a472..58df171a81 100644 --- a/.gitignore +++ b/.gitignore @@ -201,4 +201,4 @@ debug_I2_zero_band outputs/ -third_party/ \ No newline at end of file +third_party/ diff --git a/configs/glm51_math_nixl_dpep16_10node/rl.toml b/configs/glm51_math_nixl_dpep16_10node/rl.toml new file mode 100644 index 0000000000..a028c43a23 --- /dev/null +++ b/configs/glm51_math_nixl_dpep16_10node/rl.toml @@ -0,0 +1,90 @@ +output_dir = "outputs/glm51-math-nixl-dpep16-10node-nixlfix2" +seq_len = 2048 +max_steps = 500 + +[log] +level = "debug" + +[wandb] +project = "glm51-math" +name = "glm51-math-nixl-dpep16-10node-nixlfix2" + +[weight_broadcast] +type = "nixl_mx" +timeout = 12000 + +[deployment] +type = "multi_node" +num_train_nodes = 8 +num_infer_nodes = 2 +gpus_per_node = 8 + +[slurm] +job_name = "glm51-math-nixl-dpep16" +partition = "cluster" + +[trainer] +dist_timeout_seconds = 7200 + +[trainer.model] +name = "zai-org/GLM-5.1" +impl = "custom" +attn = "flash_attention_2" +fused_lm_head_token_chunk_size = 2048 +optim_cpu_offload = true +ep = 8 + +[trainer.model.compile] + +[trainer.model.ac] +freq = 1 + +[trainer.model.ac_offloading] +max_inflight_activations = 1 + +[trainer.optim] +type = "sign_sgd" +lr = 1e-6 +weight_decay = 0.1 + +[orchestrator] +batch_size = 64 +rollouts_per_example = 16 +oversampling_factor = 2 +max_off_policy_steps = 8 +filters = [] + +[orchestrator.model] +name = "zai-org/GLM-5.1-FP8" + +[orchestrator.renderer] +name = "glm-5.1" + +[orchestrator.train.sampling] +max_completion_tokens = 2048 + +[[orchestrator.train.env]] +id = "math-env" +name = "hendrycks-math" +args = { dataset_name = "PrimeIntellect/Hendrycks-Math", dataset_subset = "default", math_verify_max_workers = 128, math_verify_timeout = 60 } + +[orchestrator.buffer] +easy_threshold = 1.0 +hard_threshold = 0.0 +seed = 42 + +[inference] +enable_expert_parallel = true +all2all_backend = "allgather_reducescatter" +gpu_memory_utilization = 0.80 +enable_eplb 
= false +use_deep_gemm = true + +[inference.model] +name = "zai-org/GLM-5.1-FP8" +tool_call_parser = "glm47" +max_model_len = 2048 + +[inference.parallel] +tp = 1 +dp = 16 diff --git a/docker/modelexpress/docker-compose.yml b/docker/modelexpress/docker-compose.yml new file mode 100644 index 0000000000..cf80fbd237 --- /dev/null +++ b/docker/modelexpress/docker-compose.yml @@ -0,0 +1,31 @@ +services: + redis: + image: redis:8-alpine + network_mode: host + command: ["redis-server", "--port", "29502"] + healthcheck: + test: ["CMD", "redis-cli", "-p", "29502", "ping"] + interval: 5s + retries: 3 + start_period: 5s + restart: "no" + + modelexpress-server: + build: + context: https://github.com/ai-dynamo/modelexpress.git#b0c94ed61c65d2c2355a18508977941b9946d8b5 + network_mode: host + environment: + - MODEL_EXPRESS_SERVER_PORT=29501 + - MODEL_EXPRESS_LOG_LEVEL=info + - MX_METADATA_BACKEND=redis + - REDIS_URL=redis://localhost:29502 + - MX_HEARTBEAT_TIMEOUT_SECS=86400 + healthcheck: + test: ["CMD", "bash", "-c", " 1: self.inference.api_server_count = dp_per_node - if self.weight_broadcast is not None and self.weight_broadcast.type == "nccl": + if self.weight_broadcast is not None and self.weight_broadcast.type in ("nccl", "nixl_mx"): # Compute inference_world_size from actual worker count per server: # each api_server runs tp workers that participate in collective_rpc. 
api_server_count = self.inference.api_server_count if self.inference else 1 tp = self.inference.parallel.tp if self.inference else 1 total_infer_workers = self.deployment.total_infer_nodes * api_server_count * tp - assert self.trainer.weight_broadcast.type == "nccl" - self.trainer.weight_broadcast.host = "0.0.0.0" + if self.trainer.weight_broadcast.type == "nccl": + self.trainer.weight_broadcast.host = "0.0.0.0" self.trainer.weight_broadcast.inference_world_size = total_infer_workers - assert self.orchestrator.weight_broadcast.type == "nccl" self.orchestrator.weight_broadcast.inference_world_size = total_infer_workers return self @@ -601,10 +630,8 @@ def auto_setup_disaggregated_inference(self): ) total_infer_gpus = self.deployment.total_infer_nodes * self.deployment.gpus_per_node - if self.weight_broadcast is not None and self.weight_broadcast.type == "nccl": - assert self.trainer.weight_broadcast.type == "nccl" + if self.weight_broadcast is not None and self.weight_broadcast.type in ("nccl", "nixl_mx"): self.trainer.weight_broadcast.inference_world_size = total_infer_gpus - assert self.orchestrator.weight_broadcast.type == "nccl" self.orchestrator.weight_broadcast.inference_world_size = total_infer_gpus return self diff --git a/packages/prime-rl-configs/src/prime_rl/configs/trainer.py b/packages/prime-rl-configs/src/prime_rl/configs/trainer.py index f4d37cd9d0..fa96638a70 100644 --- a/packages/prime-rl-configs/src/prime_rl/configs/trainer.py +++ b/packages/prime-rl-configs/src/prime_rl/configs/trainer.py @@ -474,8 +474,31 @@ class NCCLWeightBroadcastConfig(BaseWeightBroadcastConfig): """Use kernel-format FP8 quantized NCCL transfer for weight updates. 
When disabled, uses default HF checkpoint-format transfer.""" +class NIXLMxWeightBroadcastConfig(BaseWeightBroadcastConfig): + """Configures NIXL (UCX/RDMA) weight transfer with Model Express rendezvous.""" + + type: Literal["nixl_mx"] = "nixl_mx" + + host: str = "localhost" + """Host for the Model Express rendezvous server.""" + + port: int = 29501 + """Port for the Model Express rendezvous server.""" + + timeout: int = 1200 + """Timeout in seconds for rendezvous and per-step transfers.""" + + # TODO: Should not be configurable, but auto-inferred + inference_world_size: int = 1 + """Number of GPUs used for inference.""" + + inference_model_name: str = "" + """HF model name or local path of the inference target.""" + + WeightBroadcastConfig: TypeAlias = Annotated[ - FileSystemWeightBroadcastConfig | NCCLWeightBroadcastConfig, Field(discriminator="type") + FileSystemWeightBroadcastConfig | NCCLWeightBroadcastConfig | NIXLMxWeightBroadcastConfig, + Field(discriminator="type"), ] diff --git a/pyproject.toml b/pyproject.toml index b8c1500971..6f8057dff2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -93,6 +93,7 @@ disagg = [ "nixl", "nixl-cu12 ; platform_machine == 'x86_64'", "vllm-router ; platform_machine == 'x86_64'", + "modelexpress", ] gpt-oss = [ "kernels", @@ -172,6 +173,7 @@ override-dependencies = [ "transformers>=5.1.0.dev0", "torch>=2.9.0", "openenv-core", + "protobuf>=6.31.1", ] [tool.uv.exclude-newer-package] @@ -238,6 +240,7 @@ vllm = [ deep-ep = { url = "https://github.com/PrimeIntellect-ai/prime-rl/releases/download/v0.5.0/deep_ep-1.2.1+29d31c0-cp312-cp312-linux_x86_64.whl" } deep-gemm = { url = "https://github.com/PrimeIntellect-ai/prime-rl/releases/download/v0.5.0/deep_gemm-2.5.0+891d57b-cp312-cp312-linux_x86_64.whl" } nixl-cu12 = { url = "https://github.com/PrimeIntellect-ai/prime-rl/releases/download/v0.5.0/nixl_cu12-0.10.1-cp312-cp312-linux_x86_64.whl" } +modelexpress = { git = "https://github.com/ai-dynamo/modelexpress.git", subdirectory = 
"modelexpress_client/python", rev = "b0c94ed" } flash-linear-attention = { git = "https://github.com/fla-org/flash-linear-attention" } flash_attn_3 = { index = "pytorch-cu128-test" } diff --git a/src/prime_rl/_compat.py b/src/prime_rl/_compat.py index 268c38376b..22b3017e30 100644 --- a/src/prime_rl/_compat.py +++ b/src/prime_rl/_compat.py @@ -4,6 +4,28 @@ Each shim documents the upstream issue and removal condition. """ +# --------------------------------------------------------------------------- +# tilelang ships a stub libcudart that proxies to the real CUDA runtime via +# dlsym(RTLD_DEFAULT, ...). If the stub's symbols are the first ones found +# (because nothing has loaded the real libcudart globally yet) its self-check +# fails and the stub aborts — hit the moment any code calls into the +# classic-cudaMalloc MemPool (used for NIXL-registered slot buffers). +# +# Preloading the real library with RTLD_GLOBAL at this very early point — +# before transformers/torch/tilelang are pulled into the process — makes +# dlsym find the real symbols first. +# +# Wrapped in try/except because CDLL can fail on machines without a real +# CUDA runtime (e.g. CI containers). 
+# --------------------------------------------------------------------------- +import ctypes as _ctypes + +try: + _ctypes.CDLL("libcudart.so", mode=_ctypes.RTLD_GLOBAL) +except OSError: + pass + + # --------------------------------------------------------------------------- # ring_flash_attn + transformers >= 5.4 # diff --git a/src/prime_rl/entrypoints/rl.py b/src/prime_rl/entrypoints/rl.py index 582d17116e..325cfad4ce 100644 --- a/src/prime_rl/entrypoints/rl.py +++ b/src/prime_rl/entrypoints/rl.py @@ -387,6 +387,7 @@ def write_slurm_script(config: RLConfig, config_dir: Path, script_path: Path) -> config_path=config_dir / RL_TOML, output_dir=config.output_dir, gpus_per_node=config.deployment.gpus_per_node, + use_nixl_mx_broadcast=config.weight_broadcast is not None and config.weight_broadcast.type == "nixl_mx", ) elif config.inference is not None and config.inference.deployment.type == "disaggregated": infer_deploy = config.inference.deployment @@ -420,6 +421,7 @@ def write_slurm_script(config: RLConfig, config_dir: Path, script_path: Path) -> if config.inference.kv_cache_offload else 0, use_nccl_broadcast=config.weight_broadcast is not None and config.weight_broadcast.type == "nccl", + use_nixl_mx_broadcast=config.weight_broadcast is not None and config.weight_broadcast.type == "nixl_mx", wandb_shared=config.wandb is not None and config.wandb.shared, ranks_filter=",".join(map(str, config.trainer.log.ranks_filter)), ) @@ -444,6 +446,7 @@ def write_slurm_script(config: RLConfig, config_dir: Path, script_path: Path) -> dp_per_node=(config.deployment.gpus_per_node // config.inference.parallel.tp) if config.inference else 1, kv_offload=config.inference is not None and config.inference.kv_cache_offload is not None, use_nccl_broadcast=config.weight_broadcast is not None and config.weight_broadcast.type == "nccl", + use_nixl_mx_broadcast=config.weight_broadcast is not None and config.weight_broadcast.type == "nixl_mx", wandb_shared=config.wandb is not None and 
config.wandb.shared, ranks_filter=",".join(map(str, config.trainer.log.ranks_filter)), ) diff --git a/src/prime_rl/inference/patches.py b/src/prime_rl/inference/patches.py index 798defc990..e8a4e8c4bf 100644 --- a/src/prime_rl/inference/patches.py +++ b/src/prime_rl/inference/patches.py @@ -1,4 +1,5 @@ import torch +from vllm.triton_utils import tl, triton from prime_rl.inference.vllm.padded_input_scrub import monkey_patch_vllm_padded_input_scrub @@ -16,6 +17,7 @@ def transformers_v5_compat(): _patch_qwen35_lora() _patch_lora_key_prefix() + monkey_patch_deep_gemm_silu_mul_quant_int64() monkey_patch_dp_engine_core_pause_resume_deadlock() monkey_patch_vllm_layerwise_reload_alias_buffers() monkey_patch_vllm_padded_input_scrub() @@ -53,6 +55,147 @@ def _copy_and_restore_kernel_tensors(layer: torch.nn.Module, info: reload_layerw logger.warning("Enabled vLLM layerwise reload alias-buffer patch.") +@triton.jit +def _silu_mul_per_token_group_quant_fp8_colmajor_int64_kernel( + y_ptr, + y_q_ptr, + y_s_ptr, + M: tl.int64, + N: tl.int64, + y_s_col_stride: tl.int64, + eps, + clamp_limit, + fp8_min: tl.constexpr, + fp8_max: tl.constexpr, + use_ue8m0: tl.constexpr, + HAS_CLAMP: tl.constexpr, + GROUP_SIZE: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, +): + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) + N_2 = N // 2 + + m_offset = (pid_m * BLOCK_M).to(tl.int64) + n_offset = (pid_n * BLOCK_N).to(tl.int64) + if m_offset >= M: + return + + offs_n = tl.arange(0, BLOCK_N).to(tl.int64) + offs_m = tl.arange(0, BLOCK_M).to(tl.int64) + + base_y_ptr = y_ptr + m_offset * N + n_offset + act_in_ptrs = base_y_ptr + offs_m[:, None] * N + offs_n[None, :] + + act_in = tl.load(act_in_ptrs) + mul_in = tl.load(act_in_ptrs + N_2) + + if HAS_CLAMP: + act_in = tl.minimum(act_in.to(tl.float32), clamp_limit).to(y_ptr.dtype.element_ty) + mul_in = tl.clamp(mul_in.to(tl.float32), -clamp_limit, clamp_limit).to(y_ptr.dtype.element_ty) + act_in = act_in.to(tl.float32) + one_f32 = 
tl.cast(1, tl.float32) + silu_out = (act_in / (one_f32 + tl.exp(-act_in))).to(y_ptr.dtype.element_ty) + y = (silu_out * mul_in).to(tl.float32) + + absmax = tl.maximum(tl.max(tl.abs(y), axis=1), eps) + scale_raw = absmax * (1.0 / fp8_max) + y_s = tl.math.exp2(tl.ceil(tl.log2(scale_raw))) if use_ue8m0 else scale_raw + y_s = tl.reshape(y_s, (BLOCK_M, 1)) + y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty) + + base_y_q_ptr = y_q_ptr + m_offset * N_2 + n_offset + y_q_ptrs = base_y_q_ptr + offs_m[:, None] * N_2 + offs_n[None, :] + tl.store(y_q_ptrs, y_q) + + group_id = n_offset // GROUP_SIZE + base_y_s_ptr = y_s_ptr + group_id * y_s_col_stride + m_offset + y_s_ptrs = base_y_s_ptr + offs_m + y_s = tl.reshape(y_s, (BLOCK_M,)) + tl.store(y_s_ptrs, y_s) + + +def _silu_mul_per_token_group_quant_fp8_colmajor_int64( + input: torch.Tensor, + output: torch.Tensor | None = None, + use_ue8m0: bool | None = None, + eps: float = 1e-10, + clamp_limit: float | None = None, +): + from vllm.platforms import current_platform + from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used + + group_size = 128 + assert input.ndim == 2 + if output is not None: + assert output.ndim == 2 + assert input.size(0) % group_size == 0 + assert input.size(1) % (group_size * 2) == 0 + + if use_ue8m0 is None: + use_ue8m0 = is_deep_gemm_e8m0_used() + + M, N = input.size() + N_2 = N // 2 + + fp8_dtype = current_platform.fp8_dtype() + if output is None: + output = torch.empty((M, N_2), dtype=fp8_dtype, device=input.device) + + output_scales = torch.empty(((N_2 // group_size), M), dtype=torch.float32, device=input.device).transpose(0, 1) + + block_m = 8 + block_n = group_size + assert M % block_m == 0 + assert N_2 % block_n == 0 + + finfo = torch.finfo(fp8_dtype) + fp8_min = -224.0 if current_platform.is_fp8_fnuz() else finfo.min + fp8_max = 224.0 if current_platform.is_fp8_fnuz() else finfo.max + + has_clamp = clamp_limit is not None + grid = (M // block_m, N_2 // block_n) + 
_silu_mul_per_token_group_quant_fp8_colmajor_int64_kernel[grid]( + input, + output, + output_scales, + M, + N, + output_scales.stride(-1), + eps, + clamp_limit if has_clamp else 0.0, + fp8_min, + fp8_max, + use_ue8m0, + has_clamp, + group_size, + block_m, + block_n, + ) + + return output, output_scales + + +def monkey_patch_deep_gemm_silu_mul_quant_int64(): + import sys + + from vllm.logger import init_logger + from vllm.model_executor.layers.quantization.utils import fp8_utils + + logger = init_logger(__name__) + + fp8_utils.silu_mul_per_token_group_quant_fp8_colmajor = _silu_mul_per_token_group_quant_fp8_colmajor_int64 + + deep_gemm_moe_module = sys.modules.get("vllm.model_executor.layers.fused_moe.experts.deep_gemm_moe") + if deep_gemm_moe_module is not None: + deep_gemm_moe_module.silu_mul_per_token_group_quant_fp8_colmajor = ( + _silu_mul_per_token_group_quant_fp8_colmajor_int64 + ) + + logger.warning("Enabled int64-addressing Triton patch for vLLM DeepGEMM SiLU/mul FP8 quant.") + + def _patch_qwen35_lora(): """Fix Qwen3.5 LoRA: align packed_modules_mapping with output_sizes. 
diff --git a/src/prime_rl/inference/vllm/server.py b/src/prime_rl/inference/vllm/server.py index 1e312cfc5a..4db432471f 100644 --- a/src/prime_rl/inference/vllm/server.py +++ b/src/prime_rl/inference/vllm/server.py @@ -167,6 +167,7 @@ def models(request: Request) -> OpenAIServingModels: WORKER_EXTENSION_CLS = { "nccl": "prime_rl.inference.vllm.worker.nccl.NCCLWeightUpdateWorker", "filesystem": "prime_rl.inference.vllm.worker.filesystem.FileSystemWeightUpdateWorker", + "nixl_mx": "prime_rl.inference.vllm.worker.nixl_mx.NIXLMxWeightUpdateWorker", } @@ -228,6 +229,16 @@ async def init_broadcaster(request: Request): return {"status": "ok"} +@router.post("/init_nixl_mx") +async def init_nixl_mx(request: Request): + data = await request.json() + await engine_client(request).collective_rpc( + "init_nixl_mx", + args=(data["host"], data["port"], data["rank_offset"]), + ) + return {"status": "ok"} + + async def custom_init_app_state( engine_client: EngineClient, state: State, diff --git a/src/prime_rl/inference/vllm/worker/nixl_mx.py b/src/prime_rl/inference/vllm/worker/nixl_mx.py new file mode 100644 index 0000000000..b23e8222cb --- /dev/null +++ b/src/prime_rl/inference/vllm/worker/nixl_mx.py @@ -0,0 +1,95 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import msgspec +import torch +from modelexpress import p2p_pb2 +from modelexpress.client import MxClient +from torch.nn import Module +from vllm.logger import init_logger + +from prime_rl.inference.vllm.worker.weight_transfer import build_expert_map, update_mla_absorbed_weights +from prime_rl.transport.mx_rendezvous import MxRendezvous +from prime_rl.transport.nixl_agent import NixlAgentWrapper, make_agent_name, pin_ucx_rail +from prime_rl.transport.wire import RendezvousPayload + +if TYPE_CHECKING: + from vllm.v1.worker.gpu_worker import Worker + + Worker = Worker # type: ignore +else: + Worker = object # type: ignore + +logger = init_logger("vllm.inference.vllm.worker_nixl_mx") + + +class 
NIXLMxWeightUpdateWorker(Worker): + """vLLM worker extension for in-place weight updates over NIXL + MX.""" + + @property + def raw_model(self) -> Module: + model_runner = self.model_runner + model = model_runner.model.runnable if hasattr(model_runner.model, "runnable") else model_runner.model + assert isinstance(model, Module) + return model + + def register_tensors_with_nixl(self, model: Module) -> None: + self._descriptors: list[p2p_pb2.TensorDescriptor] = [] + live_tensors: dict[str, torch.Tensor] = {} + for name, param in model.named_parameters(): + if not param.is_contiguous(): + raise RuntimeError(f"non-contiguous param {name} cannot be NIXL-registered") + live_tensors[name] = param.data + for name, buf in model.named_buffers(): + if name in live_tensors or not name.endswith("_weight_scale_inv"): + continue + if not buf.is_contiguous(): + raise RuntimeError(f"non-contiguous buffer {name} cannot be NIXL-registered") + live_tensors[name] = buf + + for name, tensor in live_tensors.items(): + self.agent.register_tensor(tensor) + self._descriptors.append(self.agent.make_tensor_descriptor(name, tensor)) + + def init_nixl_mx(self, host: str, port: int, rank_offset: int) -> None: + local_rank = self.device.index + global_rank = rank_offset + local_rank + inference_model_name = self.model_runner.model_config.model + + pin_ucx_rail(local_rank) + self.agent = NixlAgentWrapper(name=make_agent_name("inference", global_rank)) + self.rendezvous = MxRendezvous( + client=MxClient(server_url=f"{host}:{port}"), + role="inference", + rank=global_rank, + peer_world_size=0, + model_name=inference_model_name, + ) + + expert_map = {k: v.cpu().tolist() for k, v in build_expert_map(self.raw_model).items()} + + self.register_tensors_with_nixl(self.raw_model) + payload = RendezvousPayload( + agent_metadata=self.agent.get_metadata(), + agent_name=self.agent.name, + expert_map=expert_map, + ) + self.rendezvous.publish( + nixl_metadata=msgspec.msgpack.encode(payload), + 
tensors=self._descriptors, + ) + self.rendezvous.set_status(p2p_pb2.SOURCE_STATUS_READY) + + logger.info( + f"NIXL+MX init: rank={global_rank} tensors={len(self._descriptors)} " + f"experts={sum(len(v) for v in expert_map.values())} model={inference_model_name}" + ) + + @torch.no_grad() + def update_weights_from_path(self, weight_dir: str | None = None) -> None: + """Block until the trainer's RDMA push completes, then recompute the MLA absorbed weights and return, orchestrator can then call `/resume`""" + self.rendezvous.wait_for_all_peers_ready(timeout=1200) + torch.cuda.synchronize(self.device) + update_mla_absorbed_weights(self.raw_model) + logger.info("Weight update applied (NIXL+MX)") diff --git a/src/prime_rl/orchestrator/orchestrator.py b/src/prime_rl/orchestrator/orchestrator.py index 1b0cb4b3ee..46f6cf6728 100644 --- a/src/prime_rl/orchestrator/orchestrator.py +++ b/src/prime_rl/orchestrator/orchestrator.py @@ -33,6 +33,7 @@ import pandas as pd import verifiers as vf +from modelexpress.client import MxClient from renderers.base import create_renderer from transformers import AutoProcessor @@ -54,8 +55,10 @@ save_rollouts, ) from prime_rl.trainer.model import setup_tokenizer +from prime_rl.transport.mx_rendezvous import MxRendezvous from prime_rl.utils.client import ( init_nccl_broadcast, + init_nixl_mx_broadcast, setup_inference_pool, ) from prime_rl.utils.config import cli @@ -278,6 +281,22 @@ async def orchestrate(config: OrchestratorConfig): inference_world_size=config.weight_broadcast.inference_world_size, quantize_in_weight_transfer=config.weight_broadcast.quantize_in_weight_transfer, ) + elif config.weight_broadcast.type == "nixl_mx": + await init_nixl_mx_broadcast( + student_inference.admin_clients, + config.weight_broadcast.host, + config.weight_broadcast.port, + inference_world_size=config.weight_broadcast.inference_world_size, + ) + mx_rendezvous = MxRendezvous( + 
client=MxClient(server_url=f"{config.weight_broadcast.host}:{config.weight_broadcast.port}"), + role="orchestrator", + rank=0, + peer_world_size=1, + model_name=config.student.model.name, + ) + mx_rendezvous.publish() + scheduler.mx_rendezvous = mx_rendezvous # Setup training batch sender for sending training examples to trainer logger.info(f"Initializing training batch sender ({config.rollout_transport})") @@ -303,8 +322,8 @@ async def orchestrate(config: OrchestratorConfig): # Allow eval at resumed step by setting prev_ckpt_step one behind prev_ckpt_step = scheduler.ckpt_step - 1 - # In NCCL mode, skip existence check - weights are broadcasted, not stored on disk - check_exists = config.weight_broadcast.type != "nccl" + # In NCCL/NIXL modes, skip existence check - weights are pushed, not stored on disk + check_exists = config.weight_broadcast.type not in ("nccl", "nixl_mx") wait_timeout = config.ckpt.wait_for_weights_timeout if config.ckpt else None weights_path = get_weight_dir( config.output_dir, scheduler.ckpt_step, check_exists=check_exists, wait_timeout=wait_timeout diff --git a/src/prime_rl/orchestrator/scheduler.py b/src/prime_rl/orchestrator/scheduler.py index 02840b6443..5ce43281dd 100644 --- a/src/prime_rl/orchestrator/scheduler.py +++ b/src/prime_rl/orchestrator/scheduler.py @@ -7,6 +7,7 @@ import verifiers as vf from aiolimiter import AsyncLimiter +from modelexpress import p2p_pb2 from prime_rl.configs.orchestrator import OrchestratorConfig from prime_rl.orchestrator.buffer import Buffer @@ -101,6 +102,7 @@ def __init__( self.strict_async_level = strict_async_level self.lora_name = lora_name self.json_logging = config.log.json_logging + self.mx_rendezvous = None # student_inference is the weight-sync target. teacher_inference is set # in opd (for logprobs) and sft (for rollouts). 
rollout_inference is @@ -316,7 +318,14 @@ async def _apply_policy_update(self, next_ckpt_step: int) -> None: ) self.checkpoint_ready.clear() wait_for_ckpt_start_time = time.perf_counter() - await wait_for_path(get_step_path(get_broadcast_dir(self.config.output_dir), next_ckpt_step) / "STABLE") + if self.mx_rendezvous is not None: + await asyncio.to_thread( + self.mx_rendezvous.wait_for_all_peers_ready, + role="trainer", + status=p2p_pb2.SOURCE_STATUS_INITIALIZING, + ) + else: + await wait_for_path(get_step_path(get_broadcast_dir(self.config.output_dir), next_ckpt_step) / "STABLE") self.wait_for_ckpt_time = time.perf_counter() - wait_for_ckpt_start_time self.logger.info( f"Orchestrator resumed: checkpoint {next_ckpt_step} ready (after {self.wait_for_ckpt_time:.2f}s)" @@ -327,8 +336,17 @@ async def _apply_policy_update(self, next_ckpt_step: int) -> None: ) update_weights_start_time = time.perf_counter() - weights_path = get_step_path(get_broadcast_dir(self.config.output_dir), next_ckpt_step) - await self.student_inference.update_weights(weights_path, lora_name=self.lora_name, step=next_ckpt_step) + if self.mx_rendezvous is not None: + weights_path = None + signal_trainer = lambda: self.mx_rendezvous.set_status(p2p_pb2.SOURCE_STATUS_READY) + else: + weights_path = get_step_path(get_broadcast_dir(self.config.output_dir), next_ckpt_step) + signal_trainer = None + await self.student_inference.update_weights( + weights_path, lora_name=self.lora_name, step=next_ckpt_step, on_engines_paused=signal_trainer + ) + if self.mx_rendezvous is not None: + self.mx_rendezvous.set_status(p2p_pb2.SOURCE_STATUS_INITIALIZING) self.update_weights_time = time.perf_counter() - update_weights_start_time self.logger.debug(f"Updated weights to step {next_ckpt_step} in {self.update_weights_time:.2f}s") diff --git a/src/prime_rl/templates/multi_node_rl.sbatch.j2 b/src/prime_rl/templates/multi_node_rl.sbatch.j2 index 8079d394c7..cd5d5d6fc3 100755 --- 
a/src/prime_rl/templates/multi_node_rl.sbatch.j2 +++ b/src/prime_rl/templates/multi_node_rl.sbatch.j2 @@ -20,6 +20,9 @@ #SBATCH --exclusive #SBATCH --output={{ output_dir }}/job_%j.log #SBATCH --error={{ output_dir }}/job_%j.log +{%- if use_nixl_mx_broadcast %} +#SBATCH --signal=B:TERM@30 +{%- endif %} # Configs export NUM_TRAIN_NODES={{ num_train_nodes }} @@ -147,6 +150,22 @@ cd $PROJECT_DIR source .venv/bin/activate uv sync --all-extras +{% if use_nixl_mx_broadcast %} +# Launch Model Express server (rendezvous + metadata store) on the trainer head node. +# trap fires on normal exit, scancel/preempt/timeout (SIGTERM via #SBATCH --signal=B:TERM@30), +# Ctrl+C (SIGINT), and shell hangup (SIGHUP). SIGKILL / node failure cannot be caught. +export MX_COMPOSE_FILE=$PROJECT_DIR/docker/modelexpress/docker-compose.yml +export MX_COMPOSE_PROJECT=prime-rl-mx-${SLURM_JOB_ID} +trap 'srun --overlap --nodes=1 --ntasks=1 --nodelist=$MASTER_ADDR \ + docker compose -f "$MX_COMPOSE_FILE" -p "$MX_COMPOSE_PROJECT" down --remove-orphans >/dev/null 2>&1 || true' EXIT INT TERM HUP +echo "Starting Model Express server on $MASTER_ADDR" +# Kill ALL stale prime-rl-mx containers from previous jobs. 
+srun --overlap --nodes=1 --ntasks=1 --nodelist=$MASTER_ADDR \ + bash -c 'docker ps -aq --filter "name=prime-rl-mx" | xargs -r docker rm -f >/dev/null 2>&1 || true' +srun --overlap --nodes=1 --ntasks=1 --nodelist=$MASTER_ADDR \ + docker compose -f "$MX_COMPOSE_FILE" -p "$MX_COMPOSE_PROJECT" up -d --build --wait +{% endif %} + {% if pre_run_command %} # Pre-run command {{ pre_run_command }} @@ -188,7 +207,10 @@ srun bash -c ' # Infiniband setup IB_HCA=$(ibv_devinfo | sed -n -e "/hca_id/p" -e "/link_layer:/p" | grep -B1 InfiniBand | grep hca_id | sed -e "s/^hca_id://g" | tr -d "[[:blank:]]" |paste -sd,) export NCCL_IB_HCA=$IB_HCA - +{% if use_nixl_mx_broadcast %} + export MX_RDMA_NIC_PIN=auto + export LD_LIBRARY_PATH="$PROJECT_DIR/third_party/ucx/lib:$PROJECT_DIR/third_party/ucx/lib/ucx:${LD_LIBRARY_PATH:-}" +{% endif %} {% if num_infer_nodes > 0 -%} if [ "$SLURM_PROCID" -lt "$NUM_INFER_NODES" ]; then @@ -364,9 +386,9 @@ else if [ "$TRAIN_NODE_RANK" -eq 0 ]; then ORCHESTRATOR_ARGS="@ $CONFIG_DIR/orchestrator.toml" - ORCHESTRATOR_ARGS="$ORCHESTRATOR_ARGS --client.base-url $INFER_URLS" - ORCHESTRATOR_ARGS="$ORCHESTRATOR_ARGS --client.admin-base-url $ADMIN_URLS" - {% if use_nccl_broadcast %}ORCHESTRATOR_ARGS="$ORCHESTRATOR_ARGS --weight_broadcast.host $MASTER_ADDR" + ORCHESTRATOR_ARGS="$ORCHESTRATOR_ARGS --student.client.base-url $INFER_URLS" + ORCHESTRATOR_ARGS="$ORCHESTRATOR_ARGS --student.client.admin-base-url $ADMIN_URLS" + {% if use_nccl_broadcast or use_nixl_mx_broadcast %}ORCHESTRATOR_ARGS="$ORCHESTRATOR_ARGS --weight_broadcast.host $MASTER_ADDR" {% endif %}{% if wandb_shared %}WANDB_SHARED_LABEL=orchestrator {% endif %}uv run orchestrator $ORCHESTRATOR_ARGS \ 2>&1 | tee $OUTPUT_DIR/logs/orchestrator.log & fi @@ -394,6 +416,7 @@ else --local-ranks-filter={{ ranks_filter }} \ -m prime_rl.trainer.rl.train \ @ $CONFIG_DIR/trainer.toml \ - 2>&1 | sed -u 's/^\[[a-zA-Z]*[0-9]*\]://' | tee -a $OUTPUT_DIR/logs/trainer/node_${TRAIN_NODE_RANK}.log + {% if 
use_nixl_mx_broadcast %}--weight_broadcast.host $MASTER_ADDR \ + {% endif %}2>&1 | sed -u 's/^\[[a-zA-Z]*[0-9]*\]://' | tee -a $OUTPUT_DIR/logs/trainer/node_${TRAIN_NODE_RANK}.log {% if num_infer_nodes > 0 %} fi{% endif %} ' diff --git a/src/prime_rl/templates/single_node_rl.sbatch.j2 b/src/prime_rl/templates/single_node_rl.sbatch.j2 index 9bec929d38..908ce027a6 100644 --- a/src/prime_rl/templates/single_node_rl.sbatch.j2 +++ b/src/prime_rl/templates/single_node_rl.sbatch.j2 @@ -20,6 +20,9 @@ {%- endif %} #SBATCH --output={{ output_dir }}/job_%j.log #SBATCH --error={{ output_dir }}/job_%j.log +{%- if use_nixl_mx_broadcast %} +#SBATCH --signal=B:TERM@30 +{%- endif %} export PROJECT_DIR={{ project_dir }} @@ -29,8 +32,21 @@ cd $PROJECT_DIR source .venv/bin/activate uv sync --all-extras +{% if use_nixl_mx_broadcast %} +# Launch Model Express server (rendezvous + metadata store) co-located with the trainer. +# trap fires on normal exit, scancel/preempt/timeout (SIGTERM via #SBATCH --signal=B:TERM@30), +# Ctrl+C (SIGINT), and shell hangup (SIGHUP). SIGKILL / node failure cannot be caught. 
+export MX_COMPOSE_FILE=$PROJECT_DIR/docker/modelexpress/docker-compose.yml +export MX_COMPOSE_PROJECT=prime-rl-mx-${SLURM_JOB_ID} +trap 'docker compose -f "$MX_COMPOSE_FILE" -p "$MX_COMPOSE_PROJECT" down --remove-orphans >/dev/null 2>&1 || true' EXIT INT TERM HUP +docker ps -aq --filter "name=prime-rl-mx" | xargs -r docker rm -f >/dev/null 2>&1 || true +docker compose -f "$MX_COMPOSE_FILE" -p "$MX_COMPOSE_PROJECT" up -d --build --wait +export MX_RDMA_NIC_PIN=auto +export LD_LIBRARY_PATH="$PROJECT_DIR/third_party/ucx/lib:$PROJECT_DIR/third_party/ucx/lib/ucx:${LD_LIBRARY_PATH:-}" +{% endif %} + {% if pre_run_command %} # Pre-run command {{ pre_run_command }} {% endif %} -uv run rl @ {{ config_path }} +uv run rl @ {{ config_path }}{% if use_nixl_mx_broadcast %} --weight_broadcast.host localhost{% endif %} diff --git a/src/prime_rl/trainer/models/base.py b/src/prime_rl/trainer/models/base.py index df66984847..58df051607 100644 --- a/src/prime_rl/trainer/models/base.py +++ b/src/prime_rl/trainer/models/base.py @@ -1,6 +1,11 @@ +import torch from torch import Tensor from transformers.modeling_utils import PreTrainedModel +from prime_rl.trainer.models.conversion_spec import ConversionSpec +from prime_rl.trainer.models.slots import Slot, build_slots_for_conversion_spec +from prime_rl.trainer.parallel_dims import ParallelDims + class PreTrainedModelPrimeRL(PreTrainedModel): """ @@ -147,5 +152,45 @@ def init_buffers_post_meta(self) -> None: """ raise NotImplementedError(f"init_buffers_post_meta is not implemented for {self.__class__.__name__}") + def get_conversion_specs_for_layer(self, layer_idx: int) -> list[ConversionSpec]: + raise NotImplementedError(f"get_conversion_specs_for_layer is not implemented for {self.__class__.__name__}") + + @property + def non_layer_specs(self) -> tuple[ConversionSpec, ...]: + raise NotImplementedError(f"non_layer_specs is not implemented for {self.__class__.__name__}") + + def build_slots(self, parallel_dims: ParallelDims, 
default_conversion: str, base_dtype: torch.dtype) -> list[Slot]: + state_dict = self.state_dict() + slots: list[Slot] = [] + + for layer_idx in range(self.config.num_hidden_layers): + layer_prefix = f"model.layers.{layer_idx}" + conversion_specs = self.get_conversion_specs_for_layer(layer_idx) + + for spec in conversion_specs: + slots.extend( + build_slots_for_conversion_spec( + spec, + prefix=layer_prefix, + state_dict=state_dict, + parallel_dims=parallel_dims, + default_conversion=default_conversion, + base_dtype=base_dtype, + ) + ) + + for spec in self.non_layer_specs: + slots.extend( + build_slots_for_conversion_spec( + spec, + prefix="", + state_dict=state_dict, + parallel_dims=parallel_dims, + default_conversion=default_conversion, + base_dtype=base_dtype, + ) + ) + return slots + __all__ = ["PreTrainedModelPrimeRL"] diff --git a/src/prime_rl/trainer/models/conversion_spec.py b/src/prime_rl/trainer/models/conversion_spec.py new file mode 100644 index 0000000000..32edb48446 --- /dev/null +++ b/src/prime_rl/trainer/models/conversion_spec.py @@ -0,0 +1,74 @@ +"""Conversion specs describe how a trainer-side source tensor is transformed +into a vLLM-kernel-side destination tensor. + +* :class:`MaybeQuantize` — the *transformation* selector for one destination + slot. Carries an opaque ``conversion_type`` string (or ``None`` to let + :func:`prime_rl.trainer.models.conversions.resolve` pick the default + based on the destination dtype). All conversion-specific data — block + size, scale layout, kernel dispatch — lives in the conversion registry, + not on the spec. +* :class:`ConversionSpec` — the *routing* for one logical parameter: + which source tensors fuse into which vLLM destination, along which axis, + using which :class:`MaybeQuantize`. + +This module is model-agnostic. Per-model spec tables live next to the +model's converter and reuse the primitives here. 
+""" + +from __future__ import annotations + +from dataclasses import dataclass, field + + +@dataclass(frozen=True) +class MaybeQuantize: + """Selects a conversion. ``None`` lets the registry pick the default + based on the destination dtype. + """ + + conversion_type: str | None = None + + +@dataclass(frozen=True) +class ConversionSpec: + """How one trainer-side logical parameter converts to its vLLM destination. + + Attributes: + dst: Destination suffix after ``model.layers.{i}.``. E.g. + ``"self_attn.qkv_proj.weight"``. + sources: One or more source suffixes (after ``model.layers.{i}.``) + that fuse into ``dst``. Fused along ``cat_dim``. + cat_dim: Axis along which multiple ``sources`` are concatenated. + conversion: Conversion selector. Default leaves the choice to the + registry; override to pin e.g. ``MaybeQuantize("passthrough")`` + for tensors that must never be quantized regardless of the + inference variant. + """ + + dst: str + sources: tuple[str, ...] + cat_dim: int = 0 + conversion: MaybeQuantize = field(default_factory=MaybeQuantize) + + @property + def is_expert_spec(self) -> bool: + """True iff this spec produces a fused stacked-expert slot.""" + return self.dst.startswith("mlp.experts.") + + @staticmethod + def scale_name(weight_name: str, *, allow_direct_parameter: bool = False) -> str: + """Paired scale buffer name for a weight buffer. + + Mirrors vLLM's FP8 naming: ``.weight`` → ``.weight_scale_inv`` for + 2D linears, ``_weight`` → ``_weight_scale_inv`` for 3D stacked-expert + buffers. Trainer source tensors can also be direct ``nn.Parameter`` + entries without a ``.weight`` suffix; callers must opt into that case + because inference destination names should stay strict. 
+ """ + if weight_name.endswith(".weight"): + return weight_name.removesuffix(".weight") + ".weight_scale_inv" + if weight_name.endswith("_weight"): + return weight_name.removesuffix("_weight") + "_weight_scale_inv" + if allow_direct_parameter: + return f"{weight_name}.weight_scale_inv" + raise ValueError(f"cannot derive scale name from {weight_name!r}") diff --git a/src/prime_rl/trainer/models/conversions/__init__.py b/src/prime_rl/trainer/models/conversions/__init__.py new file mode 100644 index 0000000000..8bd86f41b9 --- /dev/null +++ b/src/prime_rl/trainer/models/conversions/__init__.py @@ -0,0 +1,109 @@ +"""Registry of named conversion kernels for trainer→inference weight transfer. + +A conversion is a function that writes one source tensor into one destination +tensor, optionally producing a paired scale buffer. Each conversion is +registered under a string name (e.g. ``"fp8_128x128"``). + +Resolution flow at startup: + +1. The trainer reads the inference model's HF ``config.json`` and calls + :func:`select_default_conversion` to pick one conversion name to use as + the default for every spec that doesn't pin its own. The choice is + driven entirely by ``config.quantization_config`` (or its absence). +2. For each :class:`~prime_rl.trainer.models.conversion_spec.ConversionSpec`, + :func:`resolve` returns the registry entry — explicit ``conversion_type`` + on the spec wins, otherwise the startup-chosen default applies. + +The registry never inspects the destination buffer's dtype; dtype is the +slot allocator's concern. 
+""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Callable + +import torch +from torch import Tensor +from transformers import AutoConfig + +ConversionFn = Callable[[Tensor, Tensor, "Tensor | None"], None] + + +@dataclass(frozen=True) +class ConversionEntry: + fn: ConversionFn + requires_scale: bool + dst_dtype: torch.dtype | None = None + preserve_source_dtype: bool = False + + +_REGISTRY: dict[str, ConversionEntry] = {} + + +def register( + name: str, + fn: ConversionFn, + *, + requires_scale: bool, + dst_dtype: torch.dtype | None = None, + preserve_source_dtype: bool = False, +) -> None: + if name in _REGISTRY: + raise ValueError(f"conversion {name!r} is already registered") + _REGISTRY[name] = ConversionEntry( + fn=fn, + requires_scale=requires_scale, + dst_dtype=dst_dtype, + preserve_source_dtype=preserve_source_dtype, + ) + + +def get(name: str) -> ConversionEntry: + if name not in _REGISTRY: + raise KeyError(f"unknown conversion {name!r}; registered: {sorted(_REGISTRY)}") + return _REGISTRY[name] + + +def select_default_conversion(inference_model_name: str) -> str: + """Pick the default conversion name for the given inference model. + + Loads the HF config and inspects ``quantization_config``: + + * absent → ``"passthrough"`` (no quantization; trainer→inference is a + plain dtype cast). + * ``quant_method == "fp8"`` with ``weight_block_size == [128, 128]`` → + ``"fp8_128x128"``. + * anything else → :class:`NotImplementedError`. 
+ """ + config = AutoConfig.from_pretrained(inference_model_name) + quant = getattr(config, "quantization_config", None) + if quant is None: + return "passthrough" + if hasattr(quant, "to_dict"): + quant = quant.to_dict() + method = quant["quant_method"] + block_size = tuple(quant.get("weight_block_size") or ()) + if method == "fp8" and block_size == (128, 128): + return "fp8_128x128" + raise NotImplementedError( + f"unsupported inference quantization: quant_method={method!r}, weight_block_size={block_size}" + ) + + +def resolve(conversion_type: str | None, default: str) -> ConversionEntry: + """Return the registry entry for a spec. Explicit name wins; otherwise ``default``.""" + return get(conversion_type or default) + + +from prime_rl.trainer.models.conversions import bf16_cast as _bf16_cast # noqa: E402, F401 +from prime_rl.trainer.models.conversions import fp8_blockwise as _fp8_blockwise # noqa: E402, F401 + +__all__ = [ + "ConversionEntry", + "ConversionFn", + "register", + "get", + "resolve", + "select_default_conversion", +] diff --git a/src/prime_rl/trainer/models/conversions/bf16_cast.py b/src/prime_rl/trainer/models/conversions/bf16_cast.py new file mode 100644 index 0000000000..9ffc847a2e --- /dev/null +++ b/src/prime_rl/trainer/models/conversions/bf16_cast.py @@ -0,0 +1,32 @@ +"""Plain dtype-cast conversion. Despite the name, casts to whatever dtype +the destination buffer is — bf16, fp32, etc. Registered as ``"passthrough"``. 
+""" + +from __future__ import annotations + +import torch +from torch import Tensor + +from prime_rl.trainer.models.conversions import register + + +def passthrough(src: Tensor, out: Tensor, scale_out: Tensor | None = None) -> None: + assert scale_out is None, "passthrough conversion takes no scale buffer" + out.copy_(src.to(out.dtype)) + + +register("passthrough", passthrough, requires_scale=False) + + +def float32_passthrough(src: Tensor, out: Tensor, scale_out: Tensor | None = None) -> None: + assert scale_out is None, "float32_passthrough conversion takes no scale buffer" + out.copy_(src.to(torch.float32)) + + +register( + "float32_passthrough", + float32_passthrough, + requires_scale=False, + dst_dtype=torch.float32, + preserve_source_dtype=True, +) diff --git a/src/prime_rl/trainer/models/conversions/fp8_blockwise.py b/src/prime_rl/trainer/models/conversions/fp8_blockwise.py new file mode 100644 index 0000000000..3a9256ab7d --- /dev/null +++ b/src/prime_rl/trainer/models/conversions/fp8_blockwise.py @@ -0,0 +1,23 @@ +"""FP8 e4m3 blockwise quantization, 128x128 blocks. Registered as ``"fp8_128x128"``. + +Dispatches between the 2D linear layer path and the 3D stacked-expert path +based on ``src.ndim``. 
+""" + +from __future__ import annotations + +from torch import Tensor + +from prime_rl.trainer.models.conversions import register +from prime_rl.trainer.models.fp8 import fp8_block_quantize, grouped_fp8_block_quantize + + +def fp8_128x128(src: Tensor, out: Tensor, scale_out: Tensor | None) -> None: + assert scale_out is not None, "fp8_128x128 requires a scale_out buffer" + if src.ndim == 3: + grouped_fp8_block_quantize(src, out=out, sf=scale_out) + else: + fp8_block_quantize(src, out=out, sf=scale_out) + + +register("fp8_128x128", fp8_128x128, requires_scale=True) diff --git a/src/prime_rl/trainer/models/fp8.py b/src/prime_rl/trainer/models/fp8.py index dc3411cc77..c04bf2f3b7 100644 --- a/src/prime_rl/trainer/models/fp8.py +++ b/src/prime_rl/trainer/models/fp8.py @@ -1,6 +1,12 @@ import torch from torch import Tensor +BLOCK_SIZE = 128 + + +def ceil_div(x: int, y: int) -> int: + return (x + y - 1) // y + def quantize_to_fp8_blockwise(weight: Tensor, block_size: int = 128) -> tuple[Tensor, Tensor]: """Quantize a 2D tensor to FP8 e4m3 with per-block scales.""" @@ -37,3 +43,46 @@ def quantize_to_fp8_blockwise(weight: Tensor, block_size: int = 128) -> tuple[Te quantized = blocks_fp8.permute(0, 2, 1, 3).reshape(padded_rows, padded_cols)[:rows, :cols].contiguous() return quantized, scales.float().contiguous() + + +def fp8_block_quantize( + x: Tensor, + out: Tensor | None = None, + sf: Tensor | None = None, +) -> tuple[Tensor, Tensor]: + """2D FP8 blockwise quantize. Optionally writes into preallocated ``out``/``sf``.""" + q, s = quantize_to_fp8_blockwise(x, BLOCK_SIZE) + if out is not None: + out.copy_(q) + if sf is not None: + sf.copy_(s) + return q, s + + +def grouped_fp8_block_quantize( + x: Tensor, + out: Tensor | None = None, + sf: Tensor | None = None, +) -> tuple[Tensor, Tensor]: + """3D (expert-major) FP8 blockwise quantize via per-expert loop. + + Optionally writes into preallocated ``out``/``sf``. 
+ """ + if x.ndim != 3: + raise ValueError(f"grouped_fp8_block_quantize expects 3D, got shape={tuple(x.shape)}") + groups, rows, cols = x.shape + q_accum = torch.empty((groups, rows, cols), dtype=torch.float8_e4m3fn, device=x.device) + s_accum = torch.empty( + (groups, ceil_div(rows, BLOCK_SIZE), ceil_div(cols, BLOCK_SIZE)), + dtype=torch.float32, + device=x.device, + ) + for g in range(groups): + q_g, s_g = quantize_to_fp8_blockwise(x[g], BLOCK_SIZE) + q_accum[g] = q_g + s_accum[g] = s_g + if out is not None: + out.copy_(q_accum) + if sf is not None: + sf.copy_(s_accum) + return q_accum, s_accum diff --git a/src/prime_rl/trainer/models/glm_moe_dsa/converting_glm_moe_dsa.py b/src/prime_rl/trainer/models/glm_moe_dsa/converting_glm_moe_dsa.py index e8e5b311b8..3d3afcafc5 100644 --- a/src/prime_rl/trainer/models/glm_moe_dsa/converting_glm_moe_dsa.py +++ b/src/prime_rl/trainer/models/glm_moe_dsa/converting_glm_moe_dsa.py @@ -1,6 +1,7 @@ import torch from torch import Tensor +from prime_rl.trainer.models.conversion_spec import ConversionSpec, MaybeQuantize from prime_rl.trainer.models.fp8 import quantize_to_fp8_blockwise @@ -146,6 +147,107 @@ def convert_tt_to_hf_moe(state_dict: dict[str, Tensor]): convert_tt_layer_to_hf(state_dict, i) +BASE_LAYER_CONVERSION_SPEC: tuple[ConversionSpec, ...] 
= ( + ConversionSpec( + "input_layernorm.weight", + ("input_layernorm.weight",), + conversion=MaybeQuantize("passthrough"), + ), + ConversionSpec( + "post_attention_layernorm.weight", + ("post_attention_layernorm.weight",), + conversion=MaybeQuantize("passthrough"), + ), + ConversionSpec( + "self_attn.fused_qkv_a_proj.weight", + ("self_attn.q_a_proj.weight", "self_attn.kv_a_proj_with_mqa.weight"), + ), + ConversionSpec( + "self_attn.q_a_layernorm.weight", + ("self_attn.q_a_layernorm.weight",), + conversion=MaybeQuantize("passthrough"), + ), + ConversionSpec( + "self_attn.kv_a_layernorm.weight", + ("self_attn.kv_a_layernorm.weight",), + conversion=MaybeQuantize("passthrough"), + ), + ConversionSpec("self_attn.q_b_proj.weight", ("self_attn.q_b_proj.weight",)), + ConversionSpec("self_attn.kv_b_proj.weight", ("self_attn.kv_b_proj.weight",)), + ConversionSpec("self_attn.o_proj.weight", ("self_attn.o_proj.weight",)), + ConversionSpec("self_attn.indexer.wq_b.weight", ("self_attn.indexer.wq_b.weight",)), + ConversionSpec( + "self_attn.indexer.wk_weights_proj.weight", + ("self_attn.indexer.wk.weight", "self_attn.indexer.weights_proj.weight"), + conversion=MaybeQuantize("passthrough"), + ), + ConversionSpec( + "self_attn.indexer.k_norm.weight", + ("self_attn.indexer.k_norm.weight",), + conversion=MaybeQuantize("float32_passthrough"), + ), + ConversionSpec( + "self_attn.indexer.k_norm.bias", + ("self_attn.indexer.k_norm.bias",), + conversion=MaybeQuantize("float32_passthrough"), + ), +) + + +DENSE_LAYER_CONVERSION_SPEC: tuple[ConversionSpec, ...] = ( + ConversionSpec("mlp.gate_up_proj.weight", ("mlp.gate_proj.weight", "mlp.up_proj.weight")), + ConversionSpec("mlp.down_proj.weight", ("mlp.down_proj.weight",)), +) + + +SPARSE_LAYER_CONVERSION_SPEC: tuple[ConversionSpec, ...] 
= ( + ConversionSpec( + "mlp.gate.weight", + ("mlp.router.gate.weight",), + conversion=MaybeQuantize("passthrough"), + ), + ConversionSpec( + "mlp.gate.e_score_correction_bias", + ("mlp.expert_bias",), + conversion=MaybeQuantize("float32_passthrough"), + ), + ConversionSpec("mlp.experts.w13_weight", ("mlp.experts.w1", "mlp.experts.w3"), cat_dim=1), + ConversionSpec("mlp.experts.w2_weight", ("mlp.experts.w2",)), + ConversionSpec( + "mlp.shared_experts.gate_up_proj.weight", + ("mlp.shared_expert.w1", "mlp.shared_expert.w3"), + ), + ConversionSpec("mlp.shared_experts.down_proj.weight", ("mlp.shared_expert.w2",)), +) + + +NON_LAYER_CONVERSION_SPEC: tuple[ConversionSpec, ...] = ( + ConversionSpec( + "model.embed_tokens.weight", + ("model.embed_tokens.weight",), + conversion=MaybeQuantize("passthrough"), + ), + ConversionSpec( + "model.norm.weight", + ("model.norm.weight",), + conversion=MaybeQuantize("passthrough"), + ), + ConversionSpec( + "lm_head.weight", + ("lm_head.weight",), + conversion=MaybeQuantize("passthrough"), + ), +) + + +CONVERSION_SPECS = { + "base_layer": BASE_LAYER_CONVERSION_SPEC, + "dense_layer": DENSE_LAYER_CONVERSION_SPEC, + "sparse_layer": SPARSE_LAYER_CONVERSION_SPEC, + "non_layer": NON_LAYER_CONVERSION_SPEC, +} + + def convert_tt_layer_to_vllm_kernel( state_dict: dict[str, Tensor], layer_idx: int, diff --git a/src/prime_rl/trainer/models/glm_moe_dsa/modeling_glm_moe_dsa.py b/src/prime_rl/trainer/models/glm_moe_dsa/modeling_glm_moe_dsa.py index 2bb90d94ff..040a5523af 100644 --- a/src/prime_rl/trainer/models/glm_moe_dsa/modeling_glm_moe_dsa.py +++ b/src/prime_rl/trainer/models/glm_moe_dsa/modeling_glm_moe_dsa.py @@ -13,8 +13,11 @@ from transformers.utils.deprecation import deprecate_kwarg from prime_rl.trainer.models.base import PreTrainedModelPrimeRL +from prime_rl.trainer.models.conversion_spec import ConversionSpec from prime_rl.trainer.models.glm_moe_dsa.configuration_glm_moe_dsa import GlmMoeDsaConfig from 
prime_rl.trainer.models.glm_moe_dsa.converting_glm_moe_dsa import ( + CONVERSION_SPECS, + NON_LAYER_CONVERSION_SPEC, convert_hf_layer_to_tt, convert_hf_to_tt_moe, convert_tt_layer_to_hf, @@ -167,6 +170,15 @@ def convert_layer_to_vllm_kernel( ) -> dict[str, Tensor]: return convert_tt_layer_to_vllm_kernel(state_dict, layer_idx, quantize_fp8=quantize_fp8) + def get_conversion_specs_for_layer(self, layer_idx: int) -> list[ConversionSpec]: + is_dense = layer_idx < self.config.first_k_dense_replace + tail = CONVERSION_SPECS["dense_layer"] if is_dense else CONVERSION_SPECS["sparse_layer"] + return list(CONVERSION_SPECS["base_layer"] + tail) + + @property + def non_layer_specs(self) -> tuple[ConversionSpec, ...]: + return NON_LAYER_CONVERSION_SPEC + @auto_docstring class GlmMoeDsaModel(GlmMoeDsaPreTrainedModel): diff --git a/src/prime_rl/trainer/models/qwen3_moe/converting_qwen3_moe.py b/src/prime_rl/trainer/models/qwen3_moe/converting_qwen3_moe.py index 48ccc5b941..f7368b9bd9 100644 --- a/src/prime_rl/trainer/models/qwen3_moe/converting_qwen3_moe.py +++ b/src/prime_rl/trainer/models/qwen3_moe/converting_qwen3_moe.py @@ -1,6 +1,8 @@ import torch from torch import Tensor +from prime_rl.trainer.models.conversion_spec import ConversionSpec, MaybeQuantize + def get_max_layer_num(state_dict: dict[str, Tensor]) -> int: """Get the maximum number of layers in the model.""" @@ -98,3 +100,74 @@ def convert_tt_to_hf_moe(state_dict: dict[str, Tensor]): num_layers = get_max_layer_num(state_dict) for i in range(num_layers): convert_tt_layer_to_hf(state_dict, i) + + +BASE_LAYER_CONVERSION_SPEC: tuple[ConversionSpec, ...] 
= ( + ConversionSpec( + "input_layernorm.weight", + ("input_layernorm.weight",), + conversion=MaybeQuantize("passthrough"), + ), + ConversionSpec( + "post_attention_layernorm.weight", + ("post_attention_layernorm.weight",), + conversion=MaybeQuantize("passthrough"), + ), + ConversionSpec( + "self_attn.q_norm.weight", + ("self_attn.q_norm.weight",), + conversion=MaybeQuantize("passthrough"), + ), + ConversionSpec( + "self_attn.k_norm.weight", + ("self_attn.k_norm.weight",), + conversion=MaybeQuantize("passthrough"), + ), + ConversionSpec( + "self_attn.qkv_proj.weight", + ("self_attn.q_proj.weight", "self_attn.k_proj.weight", "self_attn.v_proj.weight"), + ), + ConversionSpec("self_attn.o_proj.weight", ("self_attn.o_proj.weight",)), +) + + +DENSE_LAYER_CONVERSION_SPEC: tuple[ConversionSpec, ...] = ( + ConversionSpec("mlp.gate_up_proj.weight", ("mlp.gate_proj.weight", "mlp.up_proj.weight")), + ConversionSpec("mlp.down_proj.weight", ("mlp.down_proj.weight",)), +) + + +SPARSE_LAYER_CONVERSION_SPEC: tuple[ConversionSpec, ...] = ( + ConversionSpec( + "mlp.gate.weight", + ("mlp.router.gate.weight",), + conversion=MaybeQuantize("passthrough"), + ), + ConversionSpec("mlp.experts.w13_weight", ("mlp.experts.w1", "mlp.experts.w3"), cat_dim=1), + ConversionSpec("mlp.experts.w2_weight", ("mlp.experts.w2",)), +) + +NON_LAYER_CONVERSION_SPEC: tuple[ConversionSpec, ...] 
= ( + ConversionSpec( + "model.embed_tokens.weight", + ("model.embed_tokens.weight",), + conversion=MaybeQuantize("passthrough"), + ), + ConversionSpec( + "model.norm.weight", + ("model.norm.weight",), + conversion=MaybeQuantize("passthrough"), + ), + ConversionSpec( + "lm_head.weight", + ("lm_head.weight",), + conversion=MaybeQuantize("passthrough"), + ), +) + +CONVERSION_SPECS = { + "base_layer": BASE_LAYER_CONVERSION_SPEC, + "dense_layer": DENSE_LAYER_CONVERSION_SPEC, + "sparse_layer": SPARSE_LAYER_CONVERSION_SPEC, + "non_layer": NON_LAYER_CONVERSION_SPEC, +} diff --git a/src/prime_rl/trainer/models/qwen3_moe/modeling_qwen3_moe.py b/src/prime_rl/trainer/models/qwen3_moe/modeling_qwen3_moe.py index 25d83f9436..05f39b4bb4 100644 --- a/src/prime_rl/trainer/models/qwen3_moe/modeling_qwen3_moe.py +++ b/src/prime_rl/trainer/models/qwen3_moe/modeling_qwen3_moe.py @@ -27,6 +27,7 @@ from transformers.utils import TransformersKwargs, auto_docstring, can_return_tuple, logging from prime_rl.trainer.models.base import PreTrainedModelPrimeRL +from prime_rl.trainer.models.conversion_spec import ConversionSpec from prime_rl.trainer.models.layers.attn import ATTN_IMPL2CLASS, AttentionConfig from prime_rl.trainer.models.layers.lm_head import PrimeLmOutput from prime_rl.trainer.models.layers.mlp import MLP, MLPConfig @@ -35,6 +36,8 @@ from prime_rl.trainer.models.layers.rotary_emb import RotaryEmbedding, RotaryEmbeddingConfig from prime_rl.trainer.models.qwen3_moe.configuration_qwen3_moe import Qwen3MoeConfig from prime_rl.trainer.models.qwen3_moe.converting_qwen3_moe import ( + CONVERSION_SPECS, + NON_LAYER_CONVERSION_SPEC, convert_hf_layer_to_tt, convert_hf_to_tt_moe, convert_tt_layer_to_hf, @@ -170,6 +173,21 @@ def convert_layer_to_prime(cls, state_dict: dict[str, Tensor], layer_idx: int) - convert_hf_layer_to_tt(state_dict, layer_idx) return state_dict + def get_conversion_specs_for_layer(self, layer_idx: int) -> list[ConversionSpec]: + if layer_idx in 
self.config.mlp_only_layers: + is_dense = True + elif self.config.num_experts == 0: + is_dense = True + else: + is_dense = (layer_idx + 1) % self.config.decoder_sparse_step != 0 + + tail = CONVERSION_SPECS["dense_layer"] if is_dense else CONVERSION_SPECS["sparse_layer"] + return list(CONVERSION_SPECS["base_layer"] + tail) + + @property + def non_layer_specs(self) -> tuple[ConversionSpec, ...]: + return NON_LAYER_CONVERSION_SPEC + @auto_docstring class Qwen3MoeModel(Qwen3MoePreTrainedModel): diff --git a/src/prime_rl/trainer/models/slots.py b/src/prime_rl/trainer/models/slots.py new file mode 100644 index 0000000000..0d42e61bda --- /dev/null +++ b/src/prime_rl/trainer/models/slots.py @@ -0,0 +1,649 @@ +"""Trainer-side destination slots for NIXL weight transfer. + +Three slot types share a uniform protocol: + +* :class:`ShardedSlot` — non-expert param whose dim-0 is FSDP-shardable and + large enough to shard. The slot holds *this rank's* shard. Writes to + ``chunk[my_rank]`` on every inference peer. +* :class:`GatheredSlot` — non-expert param too small to shard or whose + shape doesn't divide. The slot holds the full tensor. Written once per + peer, round-robin across trainer ranks (``i % trainer_ws == my_rank``). +* :class:`ExpertSlot` — MoE expert param, fused per-rank into a 3D buffer + along ``cat_dim``. Each local expert is one chunk; writes target peers + that own each global expert (via vLLM's ``expert_map``). + +Each slot captures everything it needs at construction (``my_rank``, +``trainer_ws``, ``owned_global_experts``) so runtime methods only take +dynamic inputs (peers, source state_dict). Wire types +(:class:`~prime_rl.transport.wire.LayoutEntry`, +:class:`~prime_rl.transport.wire.PeerInfo`, +:class:`~prime_rl.transport.wire.WriteEntry`) live in the transport layer. 
+""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Optional, Protocol + +import torch +import torch.distributed as dist +from torch import Tensor +from torch.distributed.tensor import DTensor + +from prime_rl.trainer.models.conversion_spec import ConversionSpec +from prime_rl.trainer.models.conversions import ConversionEntry, resolve +from prime_rl.trainer.models.fp8 import BLOCK_SIZE, ceil_div +from prime_rl.trainer.parallel_dims import ParallelDims +from prime_rl.transport.wire import LayoutEntry, PeerInfo, WriteEntry + +# Source tensors smaller than this fall out of the per-shard NIXL path and +# are gathered instead — below ~2 MiB the per-shard fragment drops under +# 32 KiB per peer and the RDMA handle overhead eats any parallelism gain. +SMALL_NON_EXPERT_BYTES = 2 * 1024 * 1024 + + +def _maybe_cast_source_for_transfer(src: Tensor, conversion: ConversionEntry) -> Tensor: + if conversion.preserve_source_dtype or not src.is_floating_point(): + return src + return src.to(torch.bfloat16) + + +def _slot_dtype(conversion: ConversionEntry, base_dtype: torch.dtype) -> torch.dtype: + if conversion.dst_dtype is not None: + return conversion.dst_dtype + if conversion.requires_scale: + return torch.float8_e4m3fn + return base_dtype + + +# --- Slot protocol --------------------------------------------------------- # + + +class Slot(Protocol): + weight: Tensor + scale: Optional[Tensor] + spec: ConversionSpec + conversion: ConversionEntry + + @property + def buffers(self) -> list[tuple[str, Tensor, int]]: + """``(buffer_key, tensor, num_chunks)`` triples for NIXL registration.""" + ... + + def convert(self, state_dict: dict[str, Tensor]) -> None: + """Pull this slot's source tensor(s) from ``state_dict`` and write into the buffers.""" + ... + + def layout_payload(self) -> list[LayoutEntry]: + """Layout entries to publish so inference can chunk its destination.""" + ... 
+ + def build_writes(self, peers: list[PeerInfo]) -> list[WriteEntry]: + """One RDMA WRITE per ``(buffer, peer-chunk)`` this slot owns.""" + ... + + def peer_chunk_descs(self, peer: PeerInfo) -> dict[str, list[tuple[int, int, int]]]: + """Per-buffer ``(addr, size, device_id)`` tuples for ``peer``'s side. + + Returns one entry per :attr:`buffers` key. The list length matches + the peer's chunk count for that buffer (``trainer_ws`` for sharded, + ``1`` for gathered, ``len(peer.expert_map[moe_prefix])`` for + experts). Used by the transport plan to ``prep_remote`` per + (peer, buffer). + """ + ... + + +# --- Non-expert slots ------------------------------------------------------ # + + +@dataclass +class ShardedSlot: + """Non-expert slot holding one FSDP shard. Writes to every peer at chunk[my_rank].""" + + weight: Tensor + scale: Optional[Tensor] + spec: ConversionSpec + conversion: ConversionEntry + source_name: str # full source name in state_dict + slot_key: str # local + remote buffer key for the weight + scale_key: Optional[str] # local scale buffer key (per-source naming) + inference_name: str # destination name on inference side (fused dst) + inference_scale_name: Optional[str] + offset_rows: int # this source's row offset in the fused inference dst + scale_offset_rows: Optional[int] + rows: int # source's full dim-0 + scale_rows: Optional[int] + my_rank: int + trainer_ws: int + + @classmethod + def from_spec( + cls, + spec: ConversionSpec, + conversion: ConversionEntry, + prefix: str, + src_name: str, + src: Tensor, + parallel_dims: ParallelDims, + base_dtype: torch.dtype, + offset_rows: int, + scale_offset_rows: int, + ) -> "ShardedSlot": + fsdp_total = parallel_dims.dp_shard * parallel_dims.cp + if dist.is_initialized(): + mesh = parallel_dims.get_mesh("dp_shard_cp") + my_rank, trainer_ws = mesh.get_local_rank(), mesh.size() + else: + my_rank, trainer_ws = 0, 1 + src_rows = src.shape[0] + rows_per_shard = src_rows // fsdp_total + dst_dtype = 
_slot_dtype(conversion, base_dtype) + weight = torch.empty( + (rows_per_shard,) + tuple(src.shape[1:]), + dtype=dst_dtype, + device=src.device, + ) + slot_key = f"{prefix}.{src_name}" if prefix else src_name + scale: Optional[Tensor] = None + scale_key: Optional[str] = None + inference_scale_name: Optional[str] = None + scale_rows: Optional[int] = None + if conversion.requires_scale: + scale = torch.empty( + (ceil_div(weight.shape[0], BLOCK_SIZE), ceil_div(weight.shape[1], BLOCK_SIZE)), + dtype=torch.float32, + device=weight.device, + ) + scale_key = ConversionSpec.scale_name(slot_key, allow_direct_parameter=True) + inference_scale_name = ConversionSpec.scale_name(f"{prefix}.{spec.dst}" if prefix else spec.dst) + scale_rows = ceil_div(src_rows, BLOCK_SIZE) + return cls( + weight=weight, + scale=scale, + spec=spec, + conversion=conversion, + source_name=slot_key, + slot_key=slot_key, + scale_key=scale_key, + inference_name=f"{prefix}.{spec.dst}" if prefix else spec.dst, + inference_scale_name=inference_scale_name, + offset_rows=offset_rows, + scale_offset_rows=scale_offset_rows if conversion.requires_scale else None, + rows=src_rows, + scale_rows=scale_rows, + my_rank=my_rank, + trainer_ws=trainer_ws, + ) + + @property + def buffers(self) -> list[tuple[str, Tensor, int]]: + out: list[tuple[str, Tensor, int]] = [(self.slot_key, self.weight, 1)] + if self.scale is not None: + assert self.scale_key is not None + out.append((self.scale_key, self.scale, 1)) + return out + + def convert(self, state_dict: dict[str, Tensor]) -> None: + src = _maybe_cast_source_for_transfer(state_dict[self.source_name], self.conversion) + if isinstance(src, DTensor): + src = src.full_tensor() if self.weight.shape[0] == src.shape[0] else src.to_local() + self.conversion.fn(src, self.weight, self.scale) + + def layout_payload(self) -> list[LayoutEntry]: + entries = [ + LayoutEntry( + slot_key=self.slot_key, + inference_name=self.inference_name, + offset_rows=self.offset_rows, + rows=self.rows, 
+ num_chunks=self.trainer_ws, + ) + ] + if self.scale is not None: + assert self.scale_key is not None and self.scale_rows is not None + assert self.inference_scale_name is not None and self.scale_offset_rows is not None + entries.append( + LayoutEntry( + slot_key=self.scale_key, + inference_name=self.inference_scale_name, + offset_rows=self.scale_offset_rows, + rows=self.scale_rows, + num_chunks=self.trainer_ws, + ) + ) + return entries + + def build_writes(self, peers: list[PeerInfo]) -> list[WriteEntry]: + out: list[WriteEntry] = [] + for peer in peers: + for buf_key, _, _ in self.buffers: + out.append( + WriteEntry( + local_buffer_key=buf_key, + local_chunk_idx=0, + peer_name=peer.agent_name, + remote_buffer_key=buf_key, + remote_chunk_idx=self.my_rank, + tag=f"per_shard:{buf_key}", + ) + ) + return out + + def peer_chunk_descs(self, peer: PeerInfo) -> dict[str, list[tuple[int, int, int]]]: + out: dict[str, list[tuple[int, int, int]]] = {} + # Weight: ``trainer_ws`` chunks along inference dst dim 0, each + # ``rows/trainer_ws`` rows wide, starting at ``offset_rows``. 
+ weight_row_bytes = self.weight.numel() * self.weight.element_size() // self.weight.shape[0] + weight_chunk_rows = self.rows // self.trainer_ws + weight_base, _, weight_dev = peer.tensor_addrs[self.inference_name] + out[self.slot_key] = [ + ( + weight_base + (self.offset_rows + i * weight_chunk_rows) * weight_row_bytes, + weight_chunk_rows * weight_row_bytes, + weight_dev, + ) + for i in range(self.trainer_ws) + ] + if self.scale is not None: + assert self.scale_key is not None and self.scale_rows is not None + assert self.inference_scale_name is not None and self.scale_offset_rows is not None + scale_row_bytes = self.scale.numel() * self.scale.element_size() // self.scale.shape[0] + scale_chunk_rows = self.scale_rows // self.trainer_ws + scale_base, _, scale_dev = peer.tensor_addrs[self.inference_scale_name] + out[self.scale_key] = [ + ( + scale_base + (self.scale_offset_rows + i * scale_chunk_rows) * scale_row_bytes, + scale_chunk_rows * scale_row_bytes, + scale_dev, + ) + for i in range(self.trainer_ws) + ] + return out + + +@dataclass +class GatheredSlot: + """Non-expert slot holding the full tensor. 
Written once per peer, round-robin across trainer ranks.""" + + weight: Tensor + scale: Optional[Tensor] + spec: ConversionSpec + conversion: ConversionEntry + source_name: str + slot_key: str + scale_key: Optional[str] + inference_name: str + inference_scale_name: Optional[str] + offset_rows: int + scale_offset_rows: Optional[int] + rows: int + scale_rows: Optional[int] + my_rank: int + trainer_ws: int + + @classmethod + def from_spec( + cls, + spec: ConversionSpec, + conversion: ConversionEntry, + prefix: str, + src_name: str, + src: Tensor, + parallel_dims: ParallelDims, + base_dtype: torch.dtype, + offset_rows: int, + scale_offset_rows: int, + ) -> "GatheredSlot": + if dist.is_initialized(): + mesh = parallel_dims.get_mesh("dp_shard_cp") + my_rank, trainer_ws = mesh.get_local_rank(), mesh.size() + else: + my_rank, trainer_ws = 0, 1 + dst_dtype = _slot_dtype(conversion, base_dtype) + weight = torch.empty(tuple(src.shape), dtype=dst_dtype, device=src.device) + slot_key = f"{prefix}.{src_name}" if prefix else src_name + scale: Optional[Tensor] = None + scale_key: Optional[str] = None + inference_scale_name: Optional[str] = None + scale_rows: Optional[int] = None + src_rows = src.shape[0] + if conversion.requires_scale: + scale = torch.empty( + (ceil_div(weight.shape[0], BLOCK_SIZE), ceil_div(weight.shape[1], BLOCK_SIZE)), + dtype=torch.float32, + device=weight.device, + ) + scale_key = ConversionSpec.scale_name(slot_key, allow_direct_parameter=True) + inference_scale_name = ConversionSpec.scale_name(f"{prefix}.{spec.dst}" if prefix else spec.dst) + scale_rows = ceil_div(src_rows, BLOCK_SIZE) + return cls( + weight=weight, + scale=scale, + spec=spec, + conversion=conversion, + source_name=slot_key, + slot_key=slot_key, + scale_key=scale_key, + inference_name=f"{prefix}.{spec.dst}" if prefix else spec.dst, + inference_scale_name=inference_scale_name, + offset_rows=offset_rows, + scale_offset_rows=scale_offset_rows if conversion.requires_scale else None, + 
rows=src_rows, + scale_rows=scale_rows, + my_rank=my_rank, + trainer_ws=trainer_ws, + ) + + @property + def buffers(self) -> list[tuple[str, Tensor, int]]: + out: list[tuple[str, Tensor, int]] = [(self.slot_key, self.weight, 1)] + if self.scale is not None: + assert self.scale_key is not None + out.append((self.scale_key, self.scale, 1)) + return out + + def convert(self, state_dict: dict[str, Tensor]) -> None: + src = _maybe_cast_source_for_transfer(state_dict[self.source_name], self.conversion) + if isinstance(src, DTensor): + src = src.full_tensor() if self.weight.shape[0] == src.shape[0] else src.to_local() + self.conversion.fn(src, self.weight, self.scale) + + def layout_payload(self) -> list[LayoutEntry]: + entries = [ + LayoutEntry( + slot_key=self.slot_key, + inference_name=self.inference_name, + offset_rows=self.offset_rows, + rows=self.rows, + num_chunks=1, + ) + ] + if self.scale is not None: + assert self.scale_key is not None and self.scale_rows is not None + assert self.inference_scale_name is not None and self.scale_offset_rows is not None + entries.append( + LayoutEntry( + slot_key=self.scale_key, + inference_name=self.inference_scale_name, + offset_rows=self.scale_offset_rows, + rows=self.scale_rows, + num_chunks=1, + ) + ) + return entries + + def build_writes(self, peers: list[PeerInfo]) -> list[WriteEntry]: + out: list[WriteEntry] = [] + for i, peer in enumerate(peers): + if i % self.trainer_ws != self.my_rank: + continue + for buf_key, _, _ in self.buffers: + out.append( + WriteEntry( + local_buffer_key=buf_key, + local_chunk_idx=0, + peer_name=peer.agent_name, + remote_buffer_key=buf_key, + remote_chunk_idx=0, + tag=f"gather:{buf_key}", + ) + ) + return out + + def peer_chunk_descs(self, peer: PeerInfo) -> dict[str, list[tuple[int, int, int]]]: + out: dict[str, list[tuple[int, int, int]]] = {} + # Single chunk on the peer side, covering this slot's full row range + # (``rows`` rows starting at ``offset_rows``). 
@dataclass
class ExpertSlot:
    """Fused stacked-expert slot. One 3D buffer holding ``num_local`` experts.

    Each local expert is one chunk; writes go per-(local, peer) pair filtered
    by the peer's ``expert_map`` (only peers that own a global expert receive
    a WRITE for it).
    """

    weight: Tensor  # (num_local, cat_dim_size, hidden)
    scale: Optional[Tensor]  # (num_local, ceil/128, ceil/128)
    spec: ConversionSpec
    conversion: ConversionEntry
    source_names: tuple[str, ...]
    slot_key: str  # also serves as the inference destination name
    scale_key: Optional[str]
    moe_prefix: str  # e.g. "model.layers.0.mlp.experts" — keys peer.expert_map
    owned_global_experts: list[int]  # index i is this weight's chunk i
    cat_dim: int

    @classmethod
    def from_spec(
        cls,
        spec: ConversionSpec,
        conversion: ConversionEntry,
        prefix: str,
        state_dict: dict[str, Tensor],
        parallel_dims: ParallelDims,
        base_dtype: torch.dtype,
    ) -> "ExpertSlot":
        """Allocate the fused expert buffers for ``spec`` at ``prefix``.

        Args:
            spec: Expert conversion spec; ``sources`` are concatenated along
                ``cat_dim`` into ``dst``.
            conversion: Resolved conversion. Its ``requires_scale`` flag
                decides whether a float32 scale buffer is allocated.
            prefix: Module prefix joined to names with ``"."``. An empty
                prefix leaves the names untouched.
            state_dict: Trainer state dict; only shapes and the device of the
                source tensors are read here.
            parallel_dims: Supplies the ``ep`` and ``dp_shard_mod_ep`` meshes
                when expert parallelism is enabled.
            base_dtype: Forwarded to ``_slot_dtype`` with ``conversion``.

        Returns:
            A slot with uninitialized buffers and the list of global expert
            ids this rank's local shard corresponds to.
        """

        def qualified(name: str) -> str:
            # One joining rule for every key, matching how ``slot_key`` is
            # built. The sources previously always used f"{prefix}.{name}",
            # which produced ".name" keys for an empty prefix.
            return f"{prefix}.{name}" if prefix else name

        source_names = tuple(qualified(name) for name in spec.sources)

        if parallel_dims.ep_enabled:
            ep_mesh = parallel_dims.get_mesh("ep")
            fsdp_mesh = parallel_dims.get_mesh("dp_shard_mod_ep")
            ep_size, ep_rank = ep_mesh.size(), ep_mesh.get_local_rank()
            fsdp_size, fsdp_rank = fsdp_mesh.size(), fsdp_mesh.get_local_rank()
        else:
            ep_size, ep_rank, fsdp_size, fsdp_rank = 1, 0, 1, 0

        sample = state_dict[source_names[0]]
        local_sample = sample.to_local() if isinstance(sample, DTensor) else sample
        num_local_experts = local_sample.shape[0]
        if isinstance(sample, DTensor):
            total_experts = sample.shape[0]
            assert num_local_experts * fsdp_size * ep_size == total_experts, (
                f"EP partition mismatch for {spec.dst!r} at {prefix}: "
                f"local={num_local_experts} * fsdp={fsdp_size} * ep={ep_size} "
                f"!= total={total_experts}"
            )
        # Global ids are laid out EP-rank major, then FSDP-rank, then local index.
        num_experts_per_ep = num_local_experts * fsdp_size
        base = ep_rank * num_experts_per_ep + fsdp_rank * num_local_experts
        owned_global_experts = list(range(base, base + num_local_experts))

        src_local_shapes = []
        for name in source_names:
            t = state_dict[name]
            src_local_shapes.append((t.to_local() if isinstance(t, DTensor) else t).shape)
        # The destination matches the first source except along ``cat_dim``,
        # where the sources are stacked.
        dst_shape = list(src_local_shapes[0])
        dst_shape[spec.cat_dim] = sum(sh[spec.cat_dim] for sh in src_local_shapes)
        device = local_sample.device
        dst_dtype = _slot_dtype(conversion, base_dtype)
        weight = torch.empty(tuple(dst_shape), dtype=dst_dtype, device=device)
        slot_key = qualified(spec.dst)
        moe_prefix = slot_key.rsplit(".", 1)[0]
        scale: Optional[Tensor] = None
        scale_key: Optional[str] = None
        if conversion.requires_scale:
            # One scale entry per BLOCK_SIZE x BLOCK_SIZE tile of each expert.
            scale = torch.empty(
                (
                    weight.shape[0],
                    ceil_div(weight.shape[1], BLOCK_SIZE),
                    ceil_div(weight.shape[2], BLOCK_SIZE),
                ),
                dtype=torch.float32,
                device=device,
            )
            scale_key = ConversionSpec.scale_name(slot_key)
        return cls(
            weight=weight,
            scale=scale,
            spec=spec,
            conversion=conversion,
            source_names=source_names,
            slot_key=slot_key,
            scale_key=scale_key,
            moe_prefix=moe_prefix,
            owned_global_experts=owned_global_experts,
            cat_dim=spec.cat_dim,
        )

    @property
    def num_local_experts(self) -> int:
        """Number of experts held locally (leading dimension of ``weight``)."""
        return self.weight.shape[0]

    @property
    def buffers(self) -> list[tuple[str, Tensor, int]]:
        """``(key, tensor, num_chunks)`` per buffer; one chunk per local expert."""
        n = self.num_local_experts
        out: list[tuple[str, Tensor, int]] = [(self.slot_key, self.weight, n)]
        if self.scale is not None:
            assert self.scale_key is not None
            out.append((self.scale_key, self.scale, n))
        return out

    def convert(self, state_dict: dict[str, Tensor]) -> None:
        """Run the conversion from the local source shards into the slot buffers."""
        srcs = []
        for name in self.source_names:
            t = _maybe_cast_source_for_transfer(state_dict[name], self.conversion)
            srcs.append(t.to_local() if isinstance(t, DTensor) else t)
        tensor = srcs[0] if len(srcs) == 1 else torch.cat(srcs, dim=self.cat_dim)
        self.conversion.fn(tensor, self.weight, self.scale)

    def layout_payload(self) -> list[LayoutEntry]:
        """Return no layout entries: expert slots route via ``peer.expert_map``."""
        return []

    def build_writes(self, peers: list[PeerInfo]) -> list[WriteEntry]:
        """Build one WRITE per (owned expert, owning peer, buffer).

        Entries are ordered by local expert, then peer, then buffer. A peer
        receives a write for a global expert only if that id appears in its
        ``expert_map[moe_prefix]``; the position there is the remote chunk index.
        """
        buffer_keys = [key for key, _, _ in self.buffers]

        # Resolve every peer's global-id -> remote-chunk mapping once, rather
        # than scanning the peer's expert list for each local expert.
        # ``setdefault`` keeps the first occurrence, matching ``list.index``.
        peer_indexes: list[tuple[str, dict[int, int]]] = []
        for peer in peers:
            remote_index: dict[int, int] = {}
            for idx, gid in enumerate(peer.expert_map.get(self.moe_prefix, [])):
                remote_index.setdefault(gid, idx)
            peer_indexes.append((peer.agent_name, remote_index))

        out: list[WriteEntry] = []
        for local_idx, global_id in enumerate(self.owned_global_experts):
            for peer_name, remote_index in peer_indexes:
                remote_idx = remote_index.get(global_id)
                if remote_idx is None:
                    continue  # this peer does not own the expert
                for buf_key in buffer_keys:
                    out.append(
                        WriteEntry(
                            local_buffer_key=buf_key,
                            local_chunk_idx=local_idx,
                            peer_name=peer_name,
                            remote_buffer_key=buf_key,
                            remote_chunk_idx=remote_idx,
                            tag=f"expert:{buf_key}:E{global_id}",
                        )
                    )
        return out

    def peer_chunk_descs(self, peer: PeerInfo) -> dict[str, list[tuple[int, int, int]]]:
        """Describe the peer-side chunks as ``(addr, size, device_id)`` triples.

        The peer's buffer is split into one chunk per entry of its
        ``expert_map[moe_prefix]``, each the byte size of one local expert.
        """
        # Peer's chunks = peer's local experts (one per ``expert_map[moe_prefix]`` entry),
        # each one global-expert-sized row of the 3D buffer.
        peer_local_experts = len(peer.expert_map.get(self.moe_prefix, []))
        out: dict[str, list[tuple[int, int, int]]] = {}

        per_expert_bytes_w = self.weight.numel() * self.weight.element_size() // self.weight.shape[0]
        weight_base, _, weight_dev = peer.tensor_addrs[self.slot_key]
        out[self.slot_key] = [
            (weight_base + i * per_expert_bytes_w, per_expert_bytes_w, weight_dev) for i in range(peer_local_experts)
        ]
        if self.scale is not None:
            assert self.scale_key is not None
            per_expert_bytes_s = self.scale.numel() * self.scale.element_size() // self.scale.shape[0]
            scale_base, _, scale_dev = peer.tensor_addrs[self.scale_key]
            out[self.scale_key] = [
                (scale_base + i * per_expert_bytes_s, per_expert_bytes_s, scale_dev) for i in range(peer_local_experts)
            ]
        return out
``prefix``. + + Expert specs always yield one :class:`ExpertSlot`. Non-expert specs + yield one slot per source: :class:`ShardedSlot` when dim 0 divides + ``dp_shard*cp`` (and the shard size is FP8-block-aligned for quantized + sources) and the source is large enough to amortize RDMA overhead; + otherwise :class:`GatheredSlot`. Fused destinations may have one + source land sharded and another gathered. + """ + conversion = resolve(spec.conversion.conversion_type, default_conversion) + if spec.is_expert_spec: + return [ + ExpertSlot.from_spec( + spec, + conversion, + prefix=prefix, + state_dict=state_dict, + parallel_dims=parallel_dims, + base_dtype=base_dtype, + ) + ] + + fsdp_total = parallel_dims.dp_shard * parallel_dims.cp + slots: list[Slot] = [] + row_off = 0 + scale_row_off = 0 + for src_name in spec.sources: + full_src = f"{prefix}.{src_name}" if prefix else src_name + raw = state_dict[full_src] + # Pass the RAW tensor (possibly DTensor) to from_spec — it needs the + # global shape[0] to compute rows_per_shard = global_rows // fsdp_total. + # Dispatch uses the global shape too: divisibility, FP8-block alignment, + # and size threshold are all properties of the full (unfragmented) tensor. 
def setup_weight_broadcast(
    output_dir: Path,
    config: WeightBroadcastConfig,
    lora_config: LoRAConfig | None = None,
    parallel_dims=None,
) -> WeightBroadcast:
    """Construct the weight-broadcast backend selected by ``config.type``.

    Args:
        output_dir: Forwarded to the chosen backend.
        config: Broadcast configuration; ``config.type`` selects the backend
            (``"nccl"``, ``"filesystem"`` or ``"nixl_mx"``).
        lora_config: Only forwarded to the filesystem backend.
        parallel_dims: Required by the ``"nixl_mx"`` backend, ignored by the
            others.

    Returns:
        The constructed backend.

    Raises:
        ValueError: If ``config.type`` is unknown, or if ``"nixl_mx"`` is
            selected without ``parallel_dims``.
    """
    if config.type == "nccl":
        return NCCLWeightBroadcast(output_dir, config, torch.cuda.current_device())
    elif config.type == "filesystem":
        return FileSystemWeightBroadcast(output_dir, config, lora_config)
    elif config.type == "nixl_mx":
        # Explicit check rather than ``assert``: asserts are stripped under
        # ``python -O``, which would defer the failure to an attribute error
        # deep inside the backend.
        if parallel_dims is None:
            raise ValueError("nixl_mx requires parallel_dims")
        # Imported only when this backend is selected; the nixl_mx module
        # imports ``modelexpress`` at module level.
        from prime_rl.trainer.rl.broadcast.nixl_mx import NIXLMxWeightBroadcast

        return NIXLMxWeightBroadcast(output_dir, config, parallel_dims)
    else:
        raise ValueError(f"Invalid weight broadcast type: {config.type}")
class NIXLMxWeightBroadcast(WeightBroadcast):
    """Broadcast weights into the inference engine via NIXL (zero-copy RDMA).

    Slot buffers, the MX rendezvous and the transport plan are created lazily
    by :meth:`lazy_init` on the first :meth:`broadcast_weights` call. Only
    ranks for which :attr:`is_primary_hsdp_rank` is true post writes; every
    rank takes part in the two ``dist.barrier()`` calls per broadcast.
    """

    def __init__(
        self,
        output_dir: Path,
        config: NIXLMxWeightBroadcastConfig,
        parallel_dims: ParallelDims,
    ) -> None:
        """Store configuration and, on primary HSDP ranks, create the NIXL agent.

        Args:
            output_dir: Forwarded to the ``WeightBroadcast`` base class.
            config: NIXL+MX broadcast configuration (host, port, timeout,
                inference model name and world size are read later).
            parallel_dims: Used to decide whether this rank is the primary
                HSDP replica.
        """
        super().__init__(output_dir)
        self.config = config
        self.world = get_world()
        self.parallel_dims = parallel_dims

        # NOTE(review): the flattened diff does not show indentation; agent
        # creation is assumed to sit under the primary-rank guard, after the
        # UCX environment has been prepared — confirm.
        if self.is_primary_hsdp_rank:
            pin_ucx_rail(torch.cuda.current_device())
            self.nixl_agent = NixlAgentWrapper(name=make_agent_name("trainer", self.world.rank))

        self.is_initialized = False

        self._multi_run_manager = get_multi_run_manager()
        # Posted writes are drained in batches of this many handles.
        self._flush_every = 100

    @property
    def is_primary_hsdp_rank(self) -> bool:
        """True when ``dp_replicate`` is disabled or this rank is replica 0."""
        if self.parallel_dims.dp_replicate_enabled:
            return self.parallel_dims.get_mesh("dp_replicate").get_local_rank() == 0
        else:
            return True

    def register_slot_buffers_with_nixl(self) -> None:
        """Register every buffer tensor of every model slot with the NIXL agent."""
        for slot in self.model_slots:
            for _, tensor, _ in slot.buffers:
                self.nixl_agent.register_tensor(tensor)

    def publish_metadata(self) -> None:
        """Publish this rank's tensor descriptors and rendezvous payload.

        Builds one tensor descriptor per slot buffer, collects the layout
        entries of all slots, and publishes both, with the agent's name and
        serialized metadata, through ``self.rendezvous``.
        """
        descriptors: list[p2p_pb2.TensorDescriptor] = []
        for slot in self.model_slots:
            for buf_key, tensor, _ in slot.buffers:
                descriptors.append(self.nixl_agent.make_tensor_descriptor(buf_key, tensor))

        layout = []
        for slot in self.model_slots:
            layout.extend(slot.layout_payload())

        payload = RendezvousPayload(
            agent_metadata=self.nixl_agent.get_metadata(),
            agent_name=self.nixl_agent.name,
            layout=layout,
        )
        self.rendezvous.publish(
            nixl_metadata=msgspec.msgpack.encode(payload),
            tensors=descriptors,
        )

    def get_worker_metadata(self) -> list[p2p_pb2.WorkerMetadata]:
        """Wait for the peer workers to report READY, then fetch their metadata.

        Blocks for up to ``config.timeout`` seconds, polling once per second.
        """
        peer_refs = self.rendezvous.wait_for_peers(
            status=p2p_pb2.SOURCE_STATUS_READY,
            timeout=self.config.timeout,
            poll_interval=1.0,
        )

        return [self.rendezvous.fetch_peer(ref) for ref in peer_refs]

    def lazy_init(self, model: PreTrainedModelPrimeRL) -> None:
        """Build slots, rendezvous and transport plan on first call (needs the live model).

        Subsequent calls return immediately once ``is_initialized`` is set.
        """
        if self.is_initialized:
            return

        hf_config = AutoConfig.from_pretrained(self.config.inference_model_name)
        default_conversion = select_default_conversion(self.config.inference_model_name)

        # Slot buffers are allocated inside the classic-cudaMalloc pool
        # because they are registered with NIXL below.
        with classic_cuda_alloc():
            self.model_slots = model.build_slots(self.parallel_dims, default_conversion, hf_config.torch_dtype)

        self.rendezvous = MxRendezvous(
            client=MxClient(server_url=f"{self.config.host}:{self.config.port}"),
            role="trainer",
            rank=self.world.rank,
            peer_world_size=self.config.inference_world_size,
            model_name=self.config.inference_model_name,
        )
        # Buffers are registered before publishing: the published descriptors
        # point at the registered memory.
        self.register_slot_buffers_with_nixl()
        self.publish_metadata()

        # Building the plan blocks in get_worker_metadata() until the peers
        # are visible.
        self.transport_plan = TransportPlan(
            agent=self.nixl_agent, peer_metadata=self.get_worker_metadata(), slots=self.model_slots
        )

        self.is_initialized = True

    def drain(self, handles: list[tuple[Any, str]]) -> None:
        """Wait for every ``(handle, tag)`` transfer in order; the tag is error context."""
        if not handles:
            return
        for h, tag in handles:
            self.nixl_agent.wait(h, context=tag)

    @torch.no_grad()
    def broadcast_weights(self, model: nn.Module, step: int) -> None:
        """Push the current model weights to the inference peers.

        Args:
            model: Live trainer model; its ``state_dict()`` feeds the
                transport plan.
            step: Not used by this backend.
        """
        if self.is_primary_hsdp_rank:
            # Initialize on first use, then mark this rank INITIALIZING while
            # the push is pending.
            self.lazy_init(model)
            self.rendezvous.set_status(p2p_pb2.SOURCE_STATUS_INITIALIZING)

        # NOTE(review): ``self.rendezvous`` only exists after lazy_init, which
        # runs on primary HSDP ranks — assumes the master rank is one; confirm.
        if self.world.is_master:
            for idx in self._multi_run_manager.used_idxs:
                if self._multi_run_manager.ready_to_update[idx]:
                    self._multi_run_manager.ready_to_update[idx] = False
                    self.rendezvous.wait_for_all_peers_ready(role="orchestrator", timeout=self.config.timeout)

        # No rank posts writes until the master has finished waiting above.
        dist.barrier()

        if self.is_primary_hsdp_rank:
            start = time.perf_counter()

            handles = []
            for local_prep, remote_prep, entry in self.transport_plan.prepare_writes(
                model.state_dict(), self.model_slots
            ):
                # Post the write to the NIXL agent
                handle = self.nixl_agent.post_write(
                    local_prep=local_prep,
                    local_idx=entry.local_chunk_idx,
                    remote_prep=remote_prep,
                    remote_idx=entry.remote_chunk_idx,
                )
                handles.append((handle, entry.tag))
                # Bound the number of in-flight transfers: drain every
                # ``_flush_every`` posted writes.
                if len(handles) % self._flush_every == 0:
                    self.drain(handles)
                    handles.clear()

            # Wait for the final partial batch.
            self.drain(handles)

            # All posted writes have completed; mark this rank READY so the
            # finished push can be observed through MX.
            self.rendezvous.set_status(p2p_pb2.SOURCE_STATUS_READY)
            self.logger.debug(f"NIXL+MX push completed in {time.perf_counter() - start:.2f}s")

        # Keep all ranks in lockstep until the pushing ranks are done.
        dist.barrier()
This module exposes a +``MemPool`` backed by a ``CUDAPluggableAllocator`` calling ``cudaMalloc`` +/ ``cudaFree`` directly, plus :func:`classic_cuda_alloc` to scope +specific allocations into the pool. Everything else in the process keeps +using the default (expandable-segments) caching allocator. +""" + +from __future__ import annotations + +import ctypes +from contextlib import contextmanager +from pathlib import Path +from typing import Iterator + +# TileLang ships a libcudart stub that proxies to the real CUDA runtime via +# dlsym(RTLD_DEFAULT, ...). If the stub's own symbols are found first +# (nothing has loaded the real libcudart globally yet) its self-check fails +# and the stub aborts — which is what we hit the moment we enter the +# classic-cudaMalloc MemPool. Preloading the real library with RTLD_GLOBAL +# makes dlsym find it first. +try: + ctypes.CDLL("libcudart.so", mode=ctypes.RTLD_GLOBAL) +except OSError: + pass + +import torch # noqa: E402 +from torch.utils.cpp_extension import load_inline # noqa: E402 + +_SOURCE = r""" +#include +#include +extern "C" { +void* prime_rl_classic_alloc(ptrdiff_t size, int device, void* stream) { + (void) stream; + int prev = -1; + cudaGetDevice(&prev); + cudaSetDevice(device); + void* ptr = nullptr; + cudaError_t err = cudaMalloc(&ptr, (size_t) size); + if (prev >= 0) cudaSetDevice(prev); + if (err != cudaSuccess) return nullptr; + return ptr; +} +void prime_rl_classic_free(void* ptr, ptrdiff_t size, int device, void* stream) { + (void) size; (void) stream; + int prev = -1; + cudaGetDevice(&prev); + cudaSetDevice(device); + cudaFree(ptr); + if (prev >= 0) cudaSetDevice(prev); +} +} +""" + +_pool: "torch.cuda.MemPool | None" = None +_allocator_wrapper: "torch.cuda.memory.CUDAPluggableAllocator | None" = None + + +def _get_pool() -> "torch.cuda.MemPool": + global _pool, _allocator_wrapper + if _pool is not None: + return _pool + module = load_inline( + name="prime_rl_classic_cuda_alloc", + cpp_sources=[_SOURCE], + 
@dataclass
class MxRendezvous:
    """One rendezvous session per (role, rank).

    Attributes:
        client: A connected :class:`modelexpress.client.MxClient`.
        role: ``"trainer"``, ``"inference"`` or ``"orchestrator"``. Recorded
            in ``SourceIdentity.extra_parameters["role"]`` so the roles hash
            to different ``mx_source_id``s on the server.
        rank: This worker's rank within its role.
        peer_world_size: Number of workers expected on the counterpart
            role. :meth:`wait_for_peers` blocks until at least this many
            are visible.
        model_name: Inference model identifier (e.g.,
            ``"Qwen/Qwen3-235B-A22B-Thinking-2507-FP8"``).
        worker_id: Unique handle for this worker, defaulting to a fresh
            UUID. Two ranks must NOT share a ``worker_id``.
    """

    client: MxClient
    role: Role
    rank: int
    peer_world_size: int
    model_name: str
    worker_id: str = ""

    def __post_init__(self) -> None:
        """Assign a random ``worker_id`` if none was given; no source id yet."""
        if not self.worker_id:
            self.worker_id = str(uuid.uuid4())
        self._mx_source_id: str | None = None

    @property
    def peer_role(self) -> Role:
        """Counterpart role: trainers pair with inference, every other role with trainers."""
        return "inference" if self.role == "trainer" else "trainer"

    @property
    def mx_source_id(self) -> str:
        """The mx_source_id assigned by the server. Set after :meth:`publish`.

        Raises:
            RuntimeError: If :meth:`publish` has not been called yet.
        """
        if self._mx_source_id is None:
            raise RuntimeError("publish() must be called before mx_source_id is available")
        return self._mx_source_id

    def _identity(self, role: Role) -> p2p_pb2.SourceIdentity:
        """Build the ``SourceIdentity`` for ``role``; only the role tag varies."""
        return p2p_pb2.SourceIdentity(
            mx_version="0.3.0",
            mx_source_type=p2p_pb2.MX_SOURCE_TYPE_WEIGHTS,
            model_name=self.model_name,
            backend_framework=p2p_pb2.BACKEND_FRAMEWORK_VLLM,
            dtype="bfloat16",
            extra_parameters={"role": role},
        )

    @staticmethod
    def _sleep_until_next_poll(deadline: float, poll_interval: float) -> None:
        """Sleep one poll interval, but never past ``deadline``."""
        time.sleep(max(0.0, min(poll_interval, deadline - time.monotonic())))

    def publish(
        self,
        *,
        nixl_metadata: bytes = b"",
        tensors: Iterable[p2p_pb2.TensorDescriptor] = (),
    ) -> str:
        """Publish this worker's metadata. Returns the assigned ``mx_source_id``."""
        worker = p2p_pb2.WorkerMetadata(
            worker_rank=self.rank,
            nixl_metadata=nixl_metadata,
            tensors=list(tensors),
        )
        self._mx_source_id = self.client.publish_metadata(self._identity(self.role), worker, self.worker_id)
        return self._mx_source_id

    def wait_for_peers(
        self,
        *,
        status: int | None = None,
        timeout: float = 1200.0,
        poll_interval: float = 1.0,
    ) -> list[p2p_pb2.SourceInstanceRef]:
        """Block until ``peer_world_size`` peers of the counterpart role are visible.

        Args:
            status: If set, only count peers in this :class:`p2p_pb2.SourceStatus`.
            timeout: Wall-clock seconds to wait before raising :class:`TimeoutError`.
            poll_interval: Seconds between ``ListSources`` polls.

        Returns:
            The matching peer refs (at least ``peer_world_size`` of them).

        Raises:
            TimeoutError: If too few peers are visible when the deadline passes.
        """
        import logging

        _log = logging.getLogger("prime_rl.transport.mx_rendezvous")
        deadline = time.monotonic() + timeout
        peer_id = self._identity(self.peer_role)
        _logged = False
        while True:
            resp = self.client.list_sources(peer_id, status_filter=status)
            if not _logged:
                # One-time diagnostic: compare the filtered count with the
                # unfiltered one, to tell "not published" from "wrong status".
                all_resp = self.client.list_sources(peer_id)
                _log.info(
                    f"wait_for_peers: role={self.peer_role} need={self.peer_world_size} "
                    f"found_with_status={len(resp.instances)} found_any={len(all_resp.instances)} "
                    f"status_filter={status} model={peer_id.model_name}"
                )
                _logged = True
            if len(resp.instances) >= self.peer_world_size:
                return list(resp.instances)
            if time.monotonic() >= deadline:
                raise TimeoutError(
                    f"timed out after {timeout}s waiting for {self.peer_world_size} "
                    f"{self.peer_role!r} peers (saw {len(resp.instances)})"
                )
            # Clamp the sleep so a large poll interval cannot overshoot the timeout.
            self._sleep_until_next_poll(deadline, poll_interval)

    def wait_for_all_peers_ready(
        self,
        *,
        role: Role | None = None,
        status: int = p2p_pb2.SOURCE_STATUS_READY,
        timeout: float = 1200.0,
        poll_interval: float = 0.05,
    ) -> list[p2p_pb2.SourceInstanceRef]:
        """Discover peer count from MX, then block until ALL of them reach ``status``.

        Unlike :meth:`wait_for_peers` (which requires a pre-known
        ``peer_world_size``), this method first counts how many peer-role
        entries exist in MX (any status) and uses that count as the target.

        NOTE(review): the target is latched the first time a non-zero count
        is observed and is never refreshed. It equals the peer's world size
        only if every peer rank has already published by then.

        Args:
            role: Role to wait for; defaults to :attr:`peer_role`.
            status: Status every discovered peer must reach.
            timeout: Seconds covering both the discovery and the wait phase.
            poll_interval: Seconds between polls.

        Raises:
            TimeoutError: If no peer appears, or not all discovered peers
                reach ``status``, before the deadline.
        """
        target_role = role or self.peer_role
        peer_id = self._identity(target_role)
        deadline = time.monotonic() + timeout

        peer_count = 0
        while peer_count == 0:
            peer_count = len(self.client.list_sources(peer_id).instances)
            if peer_count == 0:
                if time.monotonic() >= deadline:
                    raise TimeoutError(f"timed out waiting for {target_role!r} peers to appear in MX")
                self._sleep_until_next_poll(deadline, poll_interval)

        while True:
            matched = self.client.list_sources(peer_id, status_filter=status)
            if len(matched.instances) >= peer_count:
                return list(matched.instances)
            if time.monotonic() >= deadline:
                raise TimeoutError(
                    f"timed out after {timeout}s waiting for {peer_count} "
                    f"{target_role!r} peers to reach status {status} (saw {len(matched.instances)})"
                )
            self._sleep_until_next_poll(deadline, poll_interval)

    def fetch_peer(self, ref: p2p_pb2.SourceInstanceRef) -> p2p_pb2.WorkerMetadata:
        """Fetch full :class:`WorkerMetadata` for one peer ref returned by
        :meth:`wait_for_peers`.

        Raises:
            LookupError: If the server reports the worker as not found.
        """
        resp = self.client.get_metadata(ref.mx_source_id, ref.worker_id)
        if not resp.found:
            raise LookupError(f"peer worker {ref.worker_id!r} not found at {ref.mx_source_id}")
        return resp.worker

    def set_status(self, status: int) -> None:
        """Update this worker's lifecycle status. Requires :meth:`publish` first."""
        self.client.update_status(self.mx_source_id, self.worker_id, self.rank, status)
class NixlAgentWrapper:
    """One per process. Owns a NIXL agent and its registered memory."""

    # State strings treated as terminal by post_write() and wait().
    _DONE_STATES = ("DONE", "SUCCESS")
    _ERROR_STATES = ("ERR", "ERROR", "FAIL")

    def __init__(self, name: str, backends: Sequence[str] = ("UCX",)) -> None:
        """Create the underlying NIXL agent.

        Args:
            name: Agent name; peers address this agent by it.
            backends: NIXL backends to enable, copied into ``self.backends``.
        """
        # Imported here so the module loads on machines without NIXL installed.
        from nixl_cu13._api import nixl_agent, nixl_agent_config  # type: ignore

        self.name = name
        self.backends: list[str] = list(backends)
        self._agent = nixl_agent(name, nixl_agent_config(backends=self.backends))

    # --- registration / metadata -------------------------------------------- #

    def register_tensor(self, tensor: Tensor) -> None:
        """Register ``tensor``'s memory with the agent for the configured backends."""
        self._agent.register_memory(tensor, backends=self.backends)

    def get_metadata(self) -> bytes:
        """Serialized agent metadata. A peer feeds these bytes into
        :meth:`add_remote_agent` to address this agent.
        """
        return self._agent.get_agent_metadata()

    def make_tensor_descriptor(self, name: str, tensor: Tensor) -> "p2p_pb2.TensorDescriptor":
        """Build a :class:`p2p_pb2.TensorDescriptor` pointing at this agent's
        registered memory for ``tensor``. Caller must have already passed
        ``tensor`` through :meth:`register_tensor`.

        The descriptor carries the raw data pointer, the byte size, the CUDA
        device index (``0`` for non-CUDA tensors) and the dtype name without
        the ``torch.`` prefix.
        """
        return p2p_pb2.TensorDescriptor(
            name=name,
            addr=tensor.data_ptr(),
            size=tensor.numel() * tensor.element_size(),
            device_id=tensor.device.index if tensor.device.type == "cuda" else 0,
            dtype=str(tensor.dtype).removeprefix("torch."),
        )

    # --- transport primitives ---------------------------------------------- #

    def add_remote_agent(self, peer_metadata: bytes) -> str:
        """Import a peer's serialized agent metadata. Returns the peer's agent name."""
        return self._agent.add_remote_agent(peer_metadata)

    def make_connection(self, peer_name: str) -> None:
        """Eagerly establish the connection to a peer.

        Per the original author's note, without this the first WRITE to each
        peer also pays the endpoint creation and handshake cost.
        """
        self._agent.make_connection(peer_name)

    def prep_local(self, descs: Sequence[tuple[int, int, int]]) -> Any:
        """Prepare a local-side dlist (no peer binding).

        Each entry is a ``(addr, size, device_id)`` triple within memory
        already registered on this agent.
        """
        return self._agent.prep_xfer_dlist(
            agent_name="", xfer_list=list(descs), mem_type="cuda", backends=self.backends
        )

    def prep_remote(self, peer_name: str, descs: Sequence[tuple[int, int, int]]) -> Any:
        """Prepare a remote-side dlist bound to ``peer_name``.

        ``peer_name`` must have been imported via :meth:`add_remote_agent`;
        each entry's ``(addr, size, device_id)`` must fall within an MR the
        peer registered.
        """
        return self._agent.prep_xfer_dlist(
            agent_name=peer_name, xfer_list=list(descs), mem_type="cuda", backends=self.backends
        )

    def post_write(self, *, local_prep: Any, local_idx: int, remote_prep: Any, remote_idx: int) -> Any:
        """Post a single WRITE: local chunk ``local_idx`` → remote chunk ``remote_idx``.

        Returns:
            The transfer handle; pass it to :meth:`wait`, which releases it.

        Raises:
            RuntimeError: If the post returns an error state. The handle is
                released before raising, so a failed post does not leak it.
        """
        handle = self._agent.make_prepped_xfer(
            operation="WRITE",
            local_xfer_side=local_prep,
            local_indices=[local_idx],
            remote_xfer_side=remote_prep,
            remote_indices=[remote_idx],
            backends=self.backends,
        )
        state = self._agent.transfer(handle)
        if state in self._ERROR_STATES:
            # The caller never sees the handle on this path, so release it
            # here, mirroring what wait() does for a failed transfer.
            self._agent.release_xfer_handle(handle)
            raise RuntimeError(f"nixl WRITE post returned state {state}")
        return handle

    def wait(self, handle: Any, *, context: str = "") -> None:
        """Busy-poll a transfer handle to completion. Raises on error states.

        The handle is released once a terminal state is observed, whether
        the transfer succeeded or failed. ``context`` is only used in the
        error message.
        """
        while True:
            state = self._agent.check_xfer_state(handle)
            failed = state in self._ERROR_STATES
            if failed or state in self._DONE_STATES:
                self._agent.release_xfer_handle(handle)
                if failed:
                    raise RuntimeError(f"nixl transfer ended state={state} context={context!r}")
                return
            time.sleep(0.0005)
+ """ + from modelexpress.ucx_utils import apply_nic_pin_for_device + + apply_nic_pin_for_device(local_rank) + os.environ.setdefault("UCX_TLS", "rc_mlx5,ud,cuda_copy") + os.environ.setdefault("UCX_IB_GPU_DIRECT_RDMA", "y") + os.environ.setdefault("UCX_RNDV_SCHEME", "put_zcopy") + os.environ.setdefault("UCX_RNDV_THRESH", "8192") + os.environ.setdefault("UCX_MEMTYPE_CACHE", "n") + os.environ.setdefault("UCX_WARN_UNUSED_ENV_VARS", "n") diff --git a/src/prime_rl/transport/transport_plan.py b/src/prime_rl/transport/transport_plan.py new file mode 100644 index 0000000000..f25ec5dd42 --- /dev/null +++ b/src/prime_rl/transport/transport_plan.py @@ -0,0 +1,76 @@ +from __future__ import annotations + +from collections.abc import Iterator +from typing import Any + +import msgspec +import torch +from modelexpress import p2p_pb2 +from torch import Tensor + +from prime_rl.trainer.models.slots import Slot +from prime_rl.transport.nixl_agent import NixlAgentWrapper +from prime_rl.transport.wire import PeerInfo, RendezvousPayload, WriteEntry + + +class TransportPlan: + def __init__(self, agent: NixlAgentWrapper, peer_metadata: list[p2p_pb2.WorkerMetadata], slots: list[Slot]) -> None: + """Initialize the transport plan with the given agent, peer metadata, and slots. + This method will: + 1. Decode the peer_metadata into PeerInfo objects to create the peers list + 2. Builds the writes list by calling build_writes on each slot + 3. 
Prepares the local and remote preps by calling prep_local and prep_remote on the agent for each slot and peer + """ + peers: list[PeerInfo] = [] + for meta in peer_metadata: + payload = msgspec.msgpack.decode(meta.nixl_metadata, type=RendezvousPayload) + tensor_addrs = {td.name: (td.addr, td.size, td.device_id) for td in meta.tensors} + peers.append( + PeerInfo( + agent_name=payload.agent_name, + agent_metadata=payload.agent_metadata, + tensor_addrs=tensor_addrs, + expert_map=payload.expert_map, + ) + ) + + self.peers: list[PeerInfo] = peers + self.writes: list[WriteEntry] = [] + self.local_preps: dict[str, Any] = {} + self.remote_preps: dict[tuple[str, str], Any] = {} + + for slot in slots: + self.writes.extend(slot.build_writes(peers)) + + for peer in self.peers: + name = agent.add_remote_agent(peer.agent_metadata) + agent.make_connection(name) + + for slot in slots: + for buf_key, tensor, num_chunks in slot.buffers: + total_bytes = tensor.numel() * tensor.element_size() + chunk_bytes = total_bytes // num_chunks + base_ptr = tensor.data_ptr() + dev = tensor.get_device() + local_descs = [(base_ptr + i * chunk_bytes, chunk_bytes, dev) for i in range(num_chunks)] + self.local_preps[buf_key] = agent.prep_local(local_descs) + + for peer in self.peers: + for slot in slots: + for buf_key, descs in slot.peer_chunk_descs(peer).items(): + if not descs: + continue # peer owns no chunks for this slot (e.g. unowned experts) + self.remote_preps[(peer.agent_name, buf_key)] = agent.prep_remote(peer.agent_name, descs) + + def prepare_writes(self, state_dict: dict[str, Tensor], slots: list[Slot]) -> Iterator[tuple[Any, Any, WriteEntry]]: + """Iterator yielding (local_prep, remote_prep, write_entry) tuples for each write. 
To be posted by the caller""" + + for slot in slots: + slot.convert(state_dict) + + torch.cuda.synchronize() + + for entry in self.writes: + local_prep = self.local_preps[entry.local_buffer_key] + remote_prep = self.remote_preps[(entry.peer_name, entry.remote_buffer_key)] + yield (local_prep, remote_prep, entry) diff --git a/src/prime_rl/transport/wire.py b/src/prime_rl/transport/wire.py new file mode 100644 index 0000000000..26b38a2201 --- /dev/null +++ b/src/prime_rl/transport/wire.py @@ -0,0 +1,73 @@ +"""Wire-format types exchanged between the trainer publisher, inference +receiver, and transport plan, all riding on Model Express. + +* :class:`LayoutEntry` — a trainer-side registered buffer that the inference + side needs to narrow into its destination tensor and chunk for RDMA. + Published as part of the trainer's :class:`RendezvousPayload`. +* :class:`PeerInfo` — the trainer's view of one inference peer after both + publishes have landed: NIXL agent name, serialized chunked xfer + descriptors keyed by buffer name, and the ``expert_map``. +* :class:`WriteEntry` — one RDMA WRITE description, produced by a slot + given a peer list and resolved by the transport plan into NIXL prep + handles + ``post_write`` calls. +* :class:`RendezvousPayload` — what gets msgpack-encoded into + :attr:`p2p_pb2.WorkerMetadata.nixl_metadata` so MX can carry both the + raw NIXL agent metadata blob *and* our auxiliary structured fields on + one channel. +""" + +from __future__ import annotations + +import msgspec + + +class LayoutEntry(msgspec.Struct, frozen=True): + slot_key: str + inference_name: str + offset_rows: int + rows: int + num_chunks: int # trainer_ws for sharded buffers, 1 for gathered + + +class PeerInfo(msgspec.Struct): + """One peer's payload after fetching and unpacking via MX. 
+ + ``tensor_addrs`` maps tensor name → ``(base_addr, total_bytes, + device_id)`` for every NIXL-registered buffer the peer published; the + trainer combines this with its own :class:`LayoutEntry` list at + RDMA-prep time to build per-chunk dlists locally — no need for the + peer to round-trip serialized descriptors. ``expert_map`` maps a MoE + prefix to the list of global expert IDs the peer owns. + """ + + agent_name: str + agent_metadata: bytes + tensor_addrs: dict[str, tuple[int, int, int]] + expert_map: dict[str, list[int]] + + +class WriteEntry(msgspec.Struct, frozen=True): + """One RDMA WRITE description, resolved later by the transport plan.""" + + local_buffer_key: str + local_chunk_idx: int + peer_name: str + remote_buffer_key: str + remote_chunk_idx: int + tag: str # diagnostics + + +class RendezvousPayload(msgspec.Struct): + """Packed blob carried in :attr:`p2p_pb2.WorkerMetadata.nixl_metadata`. + + Trainer publishes ``agent_metadata`` + ``agent_name`` + ``layout``. + Inference publishes ``agent_metadata`` + ``agent_name`` + ``expert_map``. + Each side publishes once; the trainer chunks remote dlists locally at + RDMA-prep time using inference's tensor base addresses (from + :attr:`p2p_pb2.WorkerMetadata.tensors`) plus its own ``layout``. + """ + + agent_metadata: bytes + agent_name: str = "" + layout: list[LayoutEntry] = msgspec.field(default_factory=list) + expert_map: dict[str, list[int]] = msgspec.field(default_factory=dict) diff --git a/src/prime_rl/utils/client.py b/src/prime_rl/utils/client.py index beb41e8ab6..d633bbbba0 100644 --- a/src/prime_rl/utils/client.py +++ b/src/prime_rl/utils/client.py @@ -2,6 +2,7 @@ import asyncio import os +from collections.abc import Callable from itertools import cycle from pathlib import Path from typing import Protocol, runtime_checkable @@ -47,7 +48,13 @@ async def wait_for_ready(self, model_name: str, timeout: int | None = None) -> N """Wait for inference pool to be ready.""" ... 
- async def update_weights(self, weight_dir: Path | None, lora_name: str | None = None, step: int = 0) -> None: + async def update_weights( + self, + weight_dir: Path | None, + lora_name: str | None = None, + step: int = 0, + on_engines_paused: Callable[[], None] | None = None, + ) -> None: """Update weights on all inference servers.""" ... @@ -119,8 +126,16 @@ async def wait_for_ready(self, model_name: str, timeout: int | None = None) -> N ) await maybe_check_has_model(self._admin_clients, model_name, skip_model_check=self._skip_model_check) - async def update_weights(self, weight_dir: Path | None, lora_name: str | None = None, step: int = 0) -> None: - await update_weights(self._admin_clients, weight_dir, lora_name=lora_name, step=step) + async def update_weights( + self, + weight_dir: Path | None, + lora_name: str | None = None, + step: int = 0, + on_engines_paused: Callable[[], None] | None = None, + ) -> None: + await update_weights( + self._admin_clients, weight_dir, lora_name=lora_name, step=step, on_engines_paused=on_engines_paused + ) def get_metrics(self) -> dict[str, float]: return {} @@ -336,6 +351,7 @@ async def update_weights( weight_dir: Path | None, lora_name: str | None = None, step: int = 0, + on_engines_paused: Callable[[], None] | None = None, ) -> None: """Update weights on static inference servers. @@ -343,8 +359,10 @@ async def update_weights( weight update, then resumes. This ensures all DP workers are idle and can participate in the collective weight transfer. - Note: The server-side /update_weights endpoint automatically resets the prefix cache - to invalidate any cached KV states computed with the old weights. + Args: + on_engines_paused: Optional callback invoked after all engines are + paused but before the weight transfer begins. Used by the NIXL+MX + path to signal the trainer that it's safe to start the RDMA push. 
""" logger = get_logger() @@ -358,11 +376,12 @@ async def _update_weights(admin_client: AsyncClient, weight_dir: str | None) -> response = await admin_client.post("/update_weights", json={"weight_dir": weight_dir}) response.raise_for_status() - # Pause engines so all DP workers drain in-flight work and can join the NCCL broadcast await _pause_engines(admin_clients) try: - # Create ready marker before servers enter receive path (used by NCCL broadcast) + if on_engines_paused is not None: + on_engines_paused() + if weight_dir is not None: nccl_ready_file = weight_dir / NCCL_READY_MARKER nccl_ready_file.parent.mkdir(parents=True, exist_ok=True) @@ -495,3 +514,28 @@ async def _init_nccl_broadcast(admin_client: AsyncClient, rank_offset: int) -> N for client_num, admin_client in enumerate(admin_clients) ] ) + + +async def init_nixl_mx_broadcast( + admin_clients: list[AsyncClient], + host: str, + port: int, + inference_world_size: int, +) -> None: + """Initialize NIXL+MX receivers on all inference servers.""" + logger = get_logger() + gpus_per_server = inference_world_size // len(admin_clients) + + logger.info( + f"Initializing NIXL+MX broadcast: {len(admin_clients)} servers, " + f"inference_world_size={inference_world_size}, gpus_per_server={gpus_per_server}" + ) + + async def _init(admin_client: AsyncClient, rank_offset: int) -> None: + response = await admin_client.post( + "/init_nixl_mx", + json={"host": host, "port": port, "rank_offset": rank_offset}, + ) + response.raise_for_status() + + await asyncio.gather(*[_init(admin_client, i * gpus_per_server) for i, admin_client in enumerate(admin_clients)]) diff --git a/src/prime_rl/utils/elastic.py b/src/prime_rl/utils/elastic.py index 5bd75dbf90..e4b4de38a8 100644 --- a/src/prime_rl/utils/elastic.py +++ b/src/prime_rl/utils/elastic.py @@ -11,6 +11,7 @@ import asyncio import socket import time +from collections.abc import Callable from dataclasses import dataclass from pathlib import Path from typing import Literal @@ 
-511,7 +512,13 @@ async def wait_for_ready(self, model_name: str = "", timeout: int | None = None, raise TimeoutError(f"Timed out waiting for {min_servers} ready servers (got {self.num_ready_servers})") - async def update_weights(self, weight_dir: Path | None, lora_name: str | None = None, step: int = 0) -> None: + async def update_weights( + self, + weight_dir: Path | None, + lora_name: str | None = None, + step: int = 0, + on_engines_paused: Callable[[], None] | None = None, + ) -> None: if lora_name is None: raise ValueError("Elastic inference pool requires LoRA training (lora_name must be set)") await self.sync_weights(weight_dir, lora_name, step) diff --git a/tests/unit/train/models/conversions/__init__.py b/tests/unit/train/models/conversions/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/unit/train/models/conversions/test_qwen3_moe.py b/tests/unit/train/models/conversions/test_qwen3_moe.py new file mode 100644 index 0000000000..6a16ea9d06 --- /dev/null +++ b/tests/unit/train/models/conversions/test_qwen3_moe.py @@ -0,0 +1,41 @@ +"""Conversion-spec resolution for Qwen3 MoE — bf16 and FP8 variants.""" + +from __future__ import annotations + +import pytest + +from prime_rl.trainer.models.conversions import resolve, select_default_conversion +from prime_rl.trainer.models.qwen3_moe.converting_qwen3_moe import ( + BASE_LAYER_CONVERSION_SPEC, + DENSE_LAYER_CONVERSION_SPEC, + NON_LAYER_CONVERSION_SPEC, + SPARSE_LAYER_CONVERSION_SPEC, +) + + +@pytest.fixture( + params=[ + pytest.param(("Qwen/Qwen3-235B-A22B-Thinking-2507", "passthrough"), id="bf16"), + pytest.param(("Qwen/Qwen3-235B-A22B-Thinking-2507-FP8", "fp8_128x128"), id="fp8"), + ] +) +def qwen3_variant(request) -> tuple[str, str]: + return request.param + + +def test_select_default_conversion(qwen3_variant): + model_name, expected = qwen3_variant + assert select_default_conversion(model_name) == expected + + +def test_specs_resolve_correctly(qwen3_variant): + _, default = 
qwen3_variant + for spec in ( + BASE_LAYER_CONVERSION_SPEC + + DENSE_LAYER_CONVERSION_SPEC + + SPARSE_LAYER_CONVERSION_SPEC + + NON_LAYER_CONVERSION_SPEC + ): + entry = resolve(spec.conversion.conversion_type, default) + expected = spec.conversion.conversion_type or default + assert entry.fn.__name__ == expected, f"{spec.dst} -> {entry.fn.__name__}" diff --git a/tests/unit/train/models/slots/__init__.py b/tests/unit/train/models/slots/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/unit/train/models/slots/test_qwen3_moe.py b/tests/unit/train/models/slots/test_qwen3_moe.py new file mode 100644 index 0000000000..d9719725c0 --- /dev/null +++ b/tests/unit/train/models/slots/test_qwen3_moe.py @@ -0,0 +1,287 @@ +"""Slot allocation for tiny Qwen3 MoE configs — bf16 and FP8 inference, single-rank GPU. + +Verifies the dispatch (ShardedSlot vs GatheredSlot vs ExpertSlot), per-slot +buffer keys, layout payloads, write entries, and an end-to-end materialize +roundtrip on the qkv-projection slot. 
+""" + +from __future__ import annotations + +import pytest +import torch +from transformers import Qwen3MoeConfig + +from prime_rl.trainer.models.fp8 import BLOCK_SIZE, ceil_div +from prime_rl.trainer.models.qwen3_moe.converting_qwen3_moe import ( + BASE_LAYER_CONVERSION_SPEC, + DENSE_LAYER_CONVERSION_SPEC, + NON_LAYER_CONVERSION_SPEC, + SPARSE_LAYER_CONVERSION_SPEC, +) +from prime_rl.trainer.models.slots import ( + SMALL_NON_EXPERT_BYTES, + ExpertSlot, + GatheredSlot, + ShardedSlot, + build_slots_for_conversion_spec, +) +from prime_rl.trainer.parallel_dims import ParallelDims +from prime_rl.transport.wire import PeerInfo + +pytestmark = pytest.mark.skipif(not torch.cuda.is_available(), reason="slot allocation lives on CUDA") + + +def _tiny_config() -> Qwen3MoeConfig: + # Hidden = 4096 so qkv (4096 rows for q, 1024 for k, 1024 for v) is large + # enough (~ 8MiB) to clear SMALL_NON_EXPERT_BYTES and land as ShardedSlot + # under trivial parallelism (dp_shard=1). + return Qwen3MoeConfig( + num_hidden_layers=2, + hidden_size=4096, + intermediate_size=128, + moe_intermediate_size=64, + num_attention_heads=32, + num_key_value_heads=8, + num_experts=4, + num_experts_per_tok=2, + decoder_sparse_step=1, + mlp_only_layers=[], + vocab_size=128, + max_position_embeddings=128, + ) + + +def _state_dict(config: Qwen3MoeConfig) -> dict[str, torch.Tensor]: + h, mh = config.hidden_size, config.moe_intermediate_size + n_q, n_kv = config.num_attention_heads, config.num_key_value_heads + head_dim = h // n_q + e, v = config.num_experts, config.vocab_size + sd: dict[str, torch.Tensor] = {} + for i in range(config.num_hidden_layers): + p = f"model.layers.{i}" + sd[f"{p}.input_layernorm.weight"] = torch.empty(h, device="cuda") + sd[f"{p}.post_attention_layernorm.weight"] = torch.empty(h, device="cuda") + sd[f"{p}.self_attn.q_norm.weight"] = torch.empty(head_dim, device="cuda") + sd[f"{p}.self_attn.k_norm.weight"] = torch.empty(head_dim, device="cuda") + sd[f"{p}.self_attn.q_proj.weight"] 
= torch.empty(n_q * head_dim, h, device="cuda") + sd[f"{p}.self_attn.k_proj.weight"] = torch.empty(n_kv * head_dim, h, device="cuda") + sd[f"{p}.self_attn.v_proj.weight"] = torch.empty(n_kv * head_dim, h, device="cuda") + sd[f"{p}.self_attn.o_proj.weight"] = torch.empty(h, n_q * head_dim, device="cuda") + sd[f"{p}.mlp.router.gate.weight"] = torch.empty(e, h, device="cuda") + sd[f"{p}.mlp.experts.w1"] = torch.empty(e, mh, h, device="cuda") + sd[f"{p}.mlp.experts.w2"] = torch.empty(e, h, mh, device="cuda") + sd[f"{p}.mlp.experts.w3"] = torch.empty(e, mh, h, device="cuda") + sd["model.embed_tokens.weight"] = torch.empty(v, h, device="cuda") + sd["model.norm.weight"] = torch.empty(h, device="cuda") + sd["lm_head.weight"] = torch.empty(v, h, device="cuda") + return sd + + +def _trivial_dims() -> ParallelDims: + return ParallelDims(dp_replicate=1, dp_shard=1, cp=1, pp=1, ep=1, world_size=1) + + +@pytest.fixture +def tiny_state() -> tuple[Qwen3MoeConfig, dict[str, torch.Tensor]]: + config = _tiny_config() + return config, _state_dict(config) + + +@pytest.fixture( + params=[ + pytest.param(("passthrough", torch.bfloat16), id="bf16"), + pytest.param(("fp8_128x128", torch.bfloat16), id="fp8"), + ] +) +def inference_target(request) -> tuple[str, torch.dtype]: + return request.param + + +def _is_dense_layer(config, layer_idx: int) -> bool: + if layer_idx in config.mlp_only_layers: + return True + if config.num_experts == 0: + return True + return (layer_idx + 1) % config.decoder_sparse_step != 0 + + +def _build(config, sd, default, base): + slots = [] + dims = _trivial_dims() + for i in range(config.num_hidden_layers): + prefix = f"model.layers.{i}" + tail = DENSE_LAYER_CONVERSION_SPEC if _is_dense_layer(config, i) else SPARSE_LAYER_CONVERSION_SPEC + for spec in BASE_LAYER_CONVERSION_SPEC + tail: + slots.extend( + build_slots_for_conversion_spec( + spec, prefix=prefix, state_dict=sd, parallel_dims=dims, default_conversion=default, base_dtype=base + ) + ) + for spec in 
NON_LAYER_CONVERSION_SPEC: + slots.extend( + build_slots_for_conversion_spec( + spec, prefix="", state_dict=sd, parallel_dims=dims, default_conversion=default, base_dtype=base + ) + ) + return slots + + +def test_dispatch_picks_expected_slot_types(tiny_state, inference_target): + """Layernorms / router gate land as GatheredSlot; large projections as + ShardedSlot; expert specs as ExpertSlot. Independent of inference target. + """ + config, sd = tiny_state + default, base = inference_target + slots = _build(config, sd, default, base) + by_key = {s.slot_key: s for s in slots} + + # Layernorms are 1D and tiny → GatheredSlot. + assert isinstance(by_key["model.layers.0.input_layernorm.weight"], GatheredSlot) + assert isinstance(by_key["model.norm.weight"], GatheredSlot) + + # Large 2D projections clear SMALL_NON_EXPERT_BYTES and divide cleanly → ShardedSlot. + q = by_key["model.layers.0.self_attn.q_proj.weight"] + assert isinstance(q, ShardedSlot) + assert q.weight.numel() * q.weight.element_size() >= SMALL_NON_EXPERT_BYTES + + # Stacked-expert specs are always ExpertSlot. + assert isinstance(by_key["model.layers.0.mlp.experts.w13_weight"], ExpertSlot) + assert isinstance(by_key["model.layers.0.mlp.experts.w2_weight"], ExpertSlot) + + +def test_qkv_three_sources_yield_three_slots(tiny_state, inference_target): + """A fused qkv ConversionSpec produces three independent slots + (one per source) with offset_rows accumulated along the fused dim. 
+ """ + config, sd = tiny_state + default, base = inference_target + slots = _build(config, sd, default, base) + qkv_slots = sorted( + ( + s + for s in slots + if s.slot_key.startswith("model.layers.0.self_attn.") + and s.slot_key.endswith(("q_proj.weight", "k_proj.weight", "v_proj.weight")) + ), + key=lambda s: s.offset_rows, + ) + q, k, v = qkv_slots + assert q.source_name == "model.layers.0.self_attn.q_proj.weight" + assert k.source_name == "model.layers.0.self_attn.k_proj.weight" + assert v.source_name == "model.layers.0.self_attn.v_proj.weight" + assert q.offset_rows == 0 + assert k.offset_rows == q.rows + assert v.offset_rows == q.rows + k.rows + # All three share the fused inference destination. + assert q.inference_name == k.inference_name == v.inference_name == "model.layers.0.self_attn.qkv_proj.weight" + + +def test_fp8_only_quantized_slots_carry_scale(tiny_state): + """Pinned (passthrough) specs never get scale buffers under FP8 inference.""" + config, sd = tiny_state + slots = _build(config, sd, "fp8_128x128", torch.bfloat16) + by_key = {s.slot_key: s for s in slots} + + # Pinned in conversion spec: layernorms, router gate, model.norm, lm_head. + for key in [ + "model.layers.0.input_layernorm.weight", + "model.layers.0.mlp.router.gate.weight", + "model.norm.weight", + "lm_head.weight", + ]: + assert by_key[key].scale is None, f"{key} must stay non-quantized" + + # Default-resolution: scale buffer present; weight is fp8. + q = by_key["model.layers.0.self_attn.q_proj.weight"] + assert q.weight.dtype == torch.float8_e4m3fn + assert q.scale is not None and q.scale.dtype == torch.float32 + # Scale uses per-source naming on trainer side, fused on inference side. 
+ assert q.scale_key == "model.layers.0.self_attn.q_proj.weight_scale_inv" + assert q.inference_scale_name == "model.layers.0.self_attn.qkv_proj.weight_scale_inv" + + +def test_expert_slot_buffer_layout_and_writes(tiny_state): + config, sd = tiny_state + slots = _build(config, sd, "fp8_128x128", torch.bfloat16) + w13 = next(s for s in slots if s.slot_key == "model.layers.0.mlp.experts.w13_weight") + assert isinstance(w13, ExpertSlot) + e = config.num_experts + mh = config.moe_intermediate_size + h = config.hidden_size + # w1 + w3 fused along cat_dim=1 → (e, 2*mh, h). + assert tuple(w13.weight.shape) == (e, 2 * mh, h) + assert w13.scale is not None + assert tuple(w13.scale.shape) == (e, ceil_div(2 * mh, BLOCK_SIZE), ceil_div(h, BLOCK_SIZE)) + # All experts owned in single-rank EP=1 setup. + assert w13.owned_global_experts == list(range(e)) + # buffers report num_chunks == num_local_experts for the 3D path. + assert [(k, t.shape, n) for k, t, n in w13.buffers] == [ + ("model.layers.0.mlp.experts.w13_weight", w13.weight.shape, e), + ("model.layers.0.mlp.experts.w13_weight_scale_inv", w13.scale.shape, e), + ] + # Expert layout uses peer.expert_map, so layout_payload is empty. + assert w13.layout_payload() == [] + + # Build writes against a fake peer that owns experts 1 and 2. + peer = PeerInfo( + agent_name="inference-test-r0", + agent_metadata=b"", + tensor_addrs={}, + expert_map={"model.layers.0.mlp.experts": [1, 2]}, + ) + writes = w13.build_writes([peer]) + # Two experts × (weight + scale) = 4 writes; peer owns global experts 1, 2 at chunks 0, 1. 
+ assert len(writes) == 4 + by_chunks = sorted((w.local_chunk_idx, w.remote_chunk_idx, w.local_buffer_key) for w in writes) + assert by_chunks == [ + (1, 0, "model.layers.0.mlp.experts.w13_weight"), + (1, 0, "model.layers.0.mlp.experts.w13_weight_scale_inv"), + (2, 1, "model.layers.0.mlp.experts.w13_weight"), + (2, 1, "model.layers.0.mlp.experts.w13_weight_scale_inv"), + ] + + +def test_sharded_slot_writes_target_my_rank_chunk_on_every_peer(tiny_state): + config, sd = tiny_state + slots = _build(config, sd, "passthrough", torch.bfloat16) + q = next(s for s in slots if s.slot_key == "model.layers.0.self_attn.q_proj.weight") + assert isinstance(q, ShardedSlot) + peers = [ + PeerInfo(agent_name=f"inference-test-r{r}", agent_metadata=b"", tensor_addrs={}, expert_map={}) + for r in range(3) + ] + writes = q.build_writes(peers) + # One write per (peer, buffer); single-rank trainer → my_rank=0, no scale buffer. + assert {(w.peer_name, w.remote_chunk_idx) for w in writes} == { + ("inference-test-r0", 0), + ("inference-test-r1", 0), + ("inference-test-r2", 0), + } + + +def test_gathered_slot_round_robin_writes_when_single_rank(tiny_state): + """With trainer_ws=1, a single trainer rank owns every gathered write.""" + config, sd = tiny_state + slots = _build(config, sd, "passthrough", torch.bfloat16) + norm = next(s for s in slots if s.slot_key == "model.layers.0.input_layernorm.weight") + assert isinstance(norm, GatheredSlot) + peers = [PeerInfo(agent_name=f"inf-r{r}", agent_metadata=b"", tensor_addrs={}, expert_map={}) for r in range(4)] + writes = norm.build_writes(peers) + # One write per peer; remote_chunk_idx=0 (gathered → single chunk). 
+ assert {w.peer_name for w in writes} == {"inf-r0", "inf-r1", "inf-r2", "inf-r3"} + assert all(w.remote_chunk_idx == 0 for w in writes) + + +def test_materialize_roundtrip_on_sharded_slot(tiny_state): + config, _ = tiny_state + sd = _state_dict(config) + g = torch.Generator(device="cuda").manual_seed(0) + for k, v in sd.items(): + sd[k] = torch.randn(v.shape, generator=g, dtype=torch.float32, device="cuda") + + slots = _build(config, sd, "passthrough", torch.bfloat16) + q = next(s for s in slots if s.slot_key == "model.layers.0.self_attn.q_proj.weight") + q.convert(sd) + # Single-rank: ShardedSlot's weight equals the source cast to bf16. + expected = sd["model.layers.0.self_attn.q_proj.weight"].to(torch.bfloat16) + torch.testing.assert_close(q.weight, expected) diff --git a/tests/unit/transport/__init__.py b/tests/unit/transport/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/unit/transport/conftest.py b/tests/unit/transport/conftest.py new file mode 100644 index 0000000000..363a43dcd3 --- /dev/null +++ b/tests/unit/transport/conftest.py @@ -0,0 +1,46 @@ +"""Session-scoped Model Express server brought up via docker-compose. + +Skips the test session entirely if Docker isn't on the PATH so transport +tests don't false-fail in environments without Docker. + +Uses a fixed compose project name (``prime-rl-mx-test``) so a previous +session that crashed before teardown doesn't leak port 29501 — ``up -d +--wait`` is idempotent against an already-healthy stack. +""" + +from __future__ import annotations + +import shutil +import subprocess +from pathlib import Path + +import pytest + +COMPOSE_FILE = Path(__file__).resolve().parents[3] / "docker" / "modelexpress" / "docker-compose.yml" +COMPOSE_PROJECT = "prime-rl-mx-test" + + +@pytest.fixture(scope="session") +def mx_server() -> str: + """Bring up the ME stack (server + redis), tear it down at session end. 
Yields ``host:port``.""" + if shutil.which("docker") is None: + pytest.skip("docker not on PATH") + if not COMPOSE_FILE.is_file(): + pytest.skip(f"compose file not found at {COMPOSE_FILE}") + + up = subprocess.run( + ["docker", "compose", "-f", str(COMPOSE_FILE), "-p", COMPOSE_PROJECT, "up", "-d", "--build", "--wait"], + capture_output=True, + text=True, + ) + if up.returncode != 0: + pytest.skip(f"docker compose up failed:\n{up.stderr}") + + try: + yield "localhost:29501" + finally: + subprocess.run( + ["docker", "compose", "-f", str(COMPOSE_FILE), "-p", COMPOSE_PROJECT, "down", "--remove-orphans"], + capture_output=True, + text=True, + ) diff --git a/tests/unit/transport/test_mx_rendezvous.py b/tests/unit/transport/test_mx_rendezvous.py new file mode 100644 index 0000000000..cde5dc631b --- /dev/null +++ b/tests/unit/transport/test_mx_rendezvous.py @@ -0,0 +1,114 @@ +"""End-to-end tests for :class:`MxRendezvous` against a live Model Express server. + +The ``mx_server`` fixture (in ``conftest.py``) brings up the stack via +docker-compose. Each test uses a unique model name so the Redis backend +(persistent across the session) doesn't cross-pollute peer counts. 
+""" + +from __future__ import annotations + +import uuid + +import pytest +from modelexpress import p2p_pb2 +from modelexpress.client import MxClient + +from prime_rl.transport.mx_rendezvous import MxRendezvous + + +@pytest.fixture +def model_name() -> str: + """Unique synthetic model name per test so peer lists don't leak across tests.""" + return f"prime-rl-test/{uuid.uuid4().hex}" + + +@pytest.fixture +def client(mx_server: str) -> MxClient: + c = MxClient(server_url=mx_server) + yield c + c.close() + + +def test_publish_returns_stable_mx_source_id(client, model_name): + rdz = MxRendezvous(client=client, role="trainer", rank=0, peer_world_size=1, model_name=model_name) + descriptor = p2p_pb2.TensorDescriptor(name="w0", addr=0, size=0, device_id=0, dtype="bfloat16") + sid = rdz.publish(nixl_metadata=b"trainer-md", tensors=[descriptor]) + assert sid + assert rdz.mx_source_id == sid + # Re-publishing the same identity returns the same hash. + assert rdz.publish(nixl_metadata=b"trainer-md", tensors=[descriptor]) == sid + + +def test_trainer_and_inference_have_distinct_mx_source_ids(client, model_name): + """Role lives in extra_parameters, so the two roles hash to different ids.""" + trainer = MxRendezvous(client=client, role="trainer", rank=0, peer_world_size=1, model_name=model_name) + inference = MxRendezvous(client=client, role="inference", rank=0, peer_world_size=1, model_name=model_name) + t_sid = trainer.publish(nixl_metadata=b"t", tensors=[]) + i_sid = inference.publish(nixl_metadata=b"i", tensors=[]) + assert t_sid != i_sid + + +def test_cross_role_discovery(client, model_name): + """2 trainers + 2 inference workers each find the other side and only the other side.""" + trainers = [ + MxRendezvous(client=client, role="trainer", rank=r, peer_world_size=2, model_name=model_name) for r in range(2) + ] + inferences = [ + MxRendezvous(client=client, role="inference", rank=r, peer_world_size=2, model_name=model_name) + for r in range(2) + ] + for t in trainers: + 
t.publish(nixl_metadata=f"t-{t.rank}".encode(), tensors=[]) + for i in inferences: + i.publish(nixl_metadata=f"i-{i.rank}".encode(), tensors=[]) + + t_peers = trainers[0].wait_for_peers(timeout=5) + i_peers = inferences[0].wait_for_peers(timeout=5) + + assert {p.worker_rank for p in t_peers} == {0, 1} + assert {p.worker_rank for p in i_peers} == {0, 1} + assert {p.worker_id for p in t_peers} == {i.worker_id for i in inferences} + assert {p.worker_id for p in i_peers} == {t.worker_id for t in trainers} + + +def test_fetch_peer_preserves_nixl_metadata(client, model_name): + trainer = MxRendezvous(client=client, role="trainer", rank=0, peer_world_size=1, model_name=model_name) + inference = MxRendezvous(client=client, role="inference", rank=0, peer_world_size=1, model_name=model_name) + descriptor = p2p_pb2.TensorDescriptor(name="w", addr=0, size=0, device_id=0, dtype="bfloat16") + trainer.publish(nixl_metadata=b"agent-bytes-from-trainer", tensors=[descriptor]) + inference.publish(nixl_metadata=b"agent-bytes-from-inference", tensors=[]) + + [t_ref] = inference.wait_for_peers(timeout=5) + t_meta = inference.fetch_peer(t_ref) + assert t_meta.nixl_metadata == b"agent-bytes-from-trainer" + assert {td.name for td in t_meta.tensors} == {"w"} + + +def test_wait_for_peers_times_out_when_none_arrive(client, model_name): + rdz = MxRendezvous(client=client, role="trainer", rank=0, peer_world_size=1, model_name=model_name) + rdz.publish(nixl_metadata=b"x", tensors=[]) + with pytest.raises(TimeoutError, match="inference"): + rdz.wait_for_peers(timeout=1.5, poll_interval=0.5) + + +def test_status_filter_gates_discovery(client, model_name): + """Inference only counts trainers in READY status.""" + trainer = MxRendezvous(client=client, role="trainer", rank=0, peer_world_size=1, model_name=model_name) + inference = MxRendezvous(client=client, role="inference", rank=0, peer_world_size=1, model_name=model_name) + trainer.publish(nixl_metadata=b"t", tensors=[]) + 
inference.publish(nixl_metadata=b"i", tensors=[]) + + # Trainer hasn't called set_status(READY) yet — inference should time out + # when it filters on READY. + with pytest.raises(TimeoutError): + inference.wait_for_peers(status=p2p_pb2.SOURCE_STATUS_READY, timeout=1.0, poll_interval=0.3) + + trainer.set_status(p2p_pb2.SOURCE_STATUS_READY) + peers = inference.wait_for_peers(status=p2p_pb2.SOURCE_STATUS_READY, timeout=5) + assert len(peers) == 1 + + +def test_set_status_before_publish_raises(client, model_name): + rdz = MxRendezvous(client=client, role="trainer", rank=0, peer_world_size=1, model_name=model_name) + with pytest.raises(RuntimeError, match="publish"): + rdz.set_status(p2p_pb2.SOURCE_STATUS_READY) diff --git a/tests/unit/transport/test_nixl_agent.py b/tests/unit/transport/test_nixl_agent.py new file mode 100644 index 0000000000..d5281ec821 --- /dev/null +++ b/tests/unit/transport/test_nixl_agent.py @@ -0,0 +1,65 @@ +"""Show that a NIXL agent's metadata + tensor descriptor make it through +Model Express end-to-end: register on the trainer side, publish, fetch +back from the inference side, byte-compare. 
+""" + +from __future__ import annotations + +import uuid + +import pytest +import torch +from modelexpress.client import MxClient + +from prime_rl.transport.mx_rendezvous import MxRendezvous + +pytestmark = [ + pytest.mark.skipif(not torch.cuda.is_available(), reason="nixl agent needs CUDA"), +] +pytest.importorskip("nixl_cu13") + +from prime_rl.transport.nixl_agent import NixlAgentWrapper, make_agent_name # noqa: E402 + + +@pytest.fixture +def model_name() -> str: + return f"prime-rl-test/{uuid.uuid4().hex}" + + +@pytest.fixture +def client(mx_server: str) -> MxClient: + c = MxClient(server_url=mx_server) + yield c + c.close() + + +def test_publish_nixl_agent_metadata_through_mx(client, model_name): + agent = NixlAgentWrapper(name=make_agent_name("trainer", 0)) + weight = torch.randn(64, 64, dtype=torch.bfloat16, device="cuda") + agent.register_tensor(weight) + + metadata = agent.get_metadata() + assert metadata, "expected non-empty NIXL agent metadata bytes" + + descriptor = agent.make_tensor_descriptor("model.layers.0.self_attn.qkv_proj.weight", weight) + assert descriptor.addr == weight.data_ptr() + assert descriptor.size == weight.numel() * weight.element_size() + assert descriptor.dtype == "bfloat16" + + trainer = MxRendezvous(client=client, role="trainer", rank=0, peer_world_size=1, model_name=model_name) + inference = MxRendezvous(client=client, role="inference", rank=0, peer_world_size=1, model_name=model_name) + trainer.publish(nixl_metadata=metadata, tensors=[descriptor]) + inference.publish(nixl_metadata=b"inference-side-stub", tensors=[]) + + [trainer_ref] = inference.wait_for_peers(timeout=5) + fetched = inference.fetch_peer(trainer_ref) + + assert fetched.nixl_metadata == metadata + assert len(fetched.tensors) == 1 + td = fetched.tensors[0] + assert (td.name, td.addr, td.size, td.dtype) == ( + descriptor.name, + descriptor.addr, + descriptor.size, + descriptor.dtype, + ) diff --git a/uv.lock b/uv.lock index b85ead13d3..34a791cab9 100644 --- 
a/uv.lock +++ b/uv.lock @@ -15,7 +15,7 @@ conflicts = [[ ]] [options] -exclude-newer = "0001-01-01T00:00:00Z" # This has no effect and is included for backwards compatibility when using relative exclude-newer values. +exclude-newer = "2026-05-13T14:35:31.084230641Z" exclude-newer-span = "P7D" [options.exclude-newer-package] @@ -71,6 +71,7 @@ overrides = [ { name = "nvidia-cudnn-cu12", specifier = ">=9.15" }, { name = "nvidia-cutlass-dsl", specifier = ">=4.4.1" }, { name = "openenv-core" }, + { name = "protobuf", specifier = ">=6.31.1" }, { name = "torch", specifier = ">=2.9.0", index = "https://download.pytorch.org/whl/cu128" }, { name = "transformers", git = "https://github.com/huggingface/transformers.git?rev=c1c3424" }, ] @@ -350,6 +351,50 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/fb/95/adcb68e20c34162e9135f370d6e31737719c2b6f94bc953fe7ed1f10fe21/authlib-1.7.2-py2.py3-none-any.whl", hash = "sha256:3e1faedc9d87e7d56a164eca3ccb6ace0d61b94abe83e92242f8dc8bba9b4a9f", size = 259548, upload-time = "2026-05-06T08:10:21.436Z" }, ] +[[package]] +name = "azure-core" +version = "1.41.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "requests", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, + { name = "typing-extensions", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 
'group-9-verifiers-policy')" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/a6/f3/b416179e408990df5db0d516283022dde0f5d0111d98c1a848e41853e81c/azure_core-1.41.0.tar.gz", hash = "sha256:f46ff5dfcd230f25cf1c19e8a34b8dc08a337b2503e268bb600a16c00db8ad5a", size = 381042, upload-time = "2026-05-07T23:30:54.302Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/5b/db/325c6d7312d2200251c52323878281045aaffcb5586612296484e4280eaa/azure_core-1.41.0-py3-none-any.whl", hash = "sha256:522b4011e8180b1a3dcd2024396a4e7fe9ac37fb8597db47163d230b5efe892d", size = 220920, upload-time = "2026-05-07T23:30:56.357Z" }, +] + +[[package]] +name = "azure-identity" +version = "1.25.3" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "azure-core", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, + { name = "cryptography", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, + { name = "msal", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 
'group-9-verifiers-policy')" }, + { name = "msal-extensions", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, + { name = "typing-extensions", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/c5/0e/3a63efb48aa4a5ae2cfca61ee152fbcb668092134d3eb8bfda472dd5c617/azure_identity-1.25.3.tar.gz", hash = "sha256:ab23c0d63015f50b630ef6c6cf395e7262f439ce06e5d07a64e874c724f8d9e6", size = 286304, upload-time = "2026-03-13T01:12:20.892Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/49/9a/417b3a533e01953a7c618884df2cb05a71e7b68bdbce4fbdb62349d2a2e8/azure_identity-1.25.3-py3-none-any.whl", hash = "sha256:f4d0b956a8146f30333e071374171f3cfa7bdb8073adb8c3814b65567aa7447c", size = 192138, upload-time = "2026-03-13T01:12:22.951Z" }, +] + +[[package]] +name = "azure-storage-blob" +version = "12.28.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "azure-core", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 
'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, + { name = "cryptography", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, + { name = "isodate", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, + { name = "typing-extensions", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/71/24/072ba8e27b0e2d8fec401e9969b429d4f5fc4c8d4f0f05f4661e11f7234a/azure_storage_blob-12.28.0.tar.gz", hash = "sha256:e7d98ea108258d29aa0efbfd591b2e2075fa1722a2fae8699f0b3c9de11eff41", size = 604225, upload-time = "2026-01-06T23:48:57.282Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d8/3a/6ef2047a072e54e1142718d433d50e9514c999a58f51abfff7902f3a72f8/azure_storage_blob-12.28.0-py3-none-any.whl", hash = "sha256:00fb1db28bf6a7b7ecaa48e3b1d5c83bfadacc5a678b77826081304bd87d6461", size = 431499, upload-time = "2026-01-06T23:48:58.995Z" }, +] + [[package]] name = "backoff" 
version = "2.2.1" @@ -438,6 +483,34 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/f7/a7/1a31561d10a089fcb46fe286766dd4e053a12f6e23b4fd1c26478aff2475/boltons-21.0.0-py2.py3-none-any.whl", hash = "sha256:b9bb7b58b2b420bbe11a6025fdef6d3e5edc9f76a42fb467afe7ca212ef9948b", size = 193723, upload-time = "2021-05-17T01:20:20.023Z" }, ] +[[package]] +name = "boto3" +version = "1.43.6" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "botocore", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, + { name = "jmespath", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, + { name = "s3transfer", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/0a/37/78c630d1308964aa9abf44951d9c4df776546ff37251ec2434944e205c4e/boto3-1.43.6.tar.gz", hash = "sha256:e6315effaf12b890b99956e6f8e2c3000a3f64e4ee91943cec3895ce9a836afb", size = 113153, upload-time = 
"2026-05-07T20:49:59.694Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c8/e2/3c2eef44f55eafab256836d1d9479bd6a74f70c26cbfdc0639a0e23e4327/boto3-1.43.6-py3-none-any.whl", hash = "sha256:179601ec2992726a718053bf41e43c223ceba397d31ceab11f64d9c910d9fc3a", size = 140502, upload-time = "2026-05-07T20:49:57.8Z" }, +] + +[[package]] +name = "botocore" +version = "1.43.6" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "jmespath", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, + { name = "python-dateutil", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, + { name = "urllib3", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/79/a7/23d0f5028011455096a1eeac0ddf3cbe147b3e855e127342f8202552194d/botocore-1.43.6.tar.gz", hash = "sha256:b1e395b347356860398da42e61c808cf1e34b6fa7180cf2b9d87d986e1a06ba0", size = 15336070, upload-time = 
"2026-05-07T20:49:48.14Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e5/c8/6f47223840e8d8cfa8c9f7c0ec1b77970417f257fc885169ff4f6326ce09/botocore-1.43.6-py3-none-any.whl", hash = "sha256:b6d1fdbc6f65a5fe0b7e947823aa37535d3f39f3ba4d21110fab1f55bbbcc04b", size = 15017094, upload-time = "2026-05-07T20:49:44.964Z" }, +] + [[package]] name = "bracex" version = "2.6" @@ -1601,6 +1674,87 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/a7/e6/4129d9a3baa72d747533bb33376543ccadd9a7f9944e5a6e3ae2e245f5d6/glom-25.12.0-py3-none-any.whl", hash = "sha256:b9f21e77f71a6576a43864e85066b8cc3f0f778d0d50961563f8981377a6dcb1", size = 103295, upload-time = "2025-12-29T06:29:06.074Z" }, ] +[[package]] +name = "google-api-core" +version = "2.30.3" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "google-auth", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, + { name = "googleapis-common-protos", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, + { name = "proto-plus", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') 
or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, + { name = "protobuf", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, + { name = "requests", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/16/ce/502a57fb0ec752026d24df1280b162294b22a0afb98a326084f9a979138b/google_api_core-2.30.3.tar.gz", hash = "sha256:e601a37f148585319b26db36e219df68c5d07b6382cff2d580e83404e44d641b", size = 177001, upload-time = "2026-04-10T00:41:28.035Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/03/15/e56f351cf6ef1cfea58e6ac226a7318ed1deb2218c4b3cc9bd9e4b786c5a/google_api_core-2.30.3-py3-none-any.whl", hash = "sha256:a85761ba72c444dad5d611c2220633480b2b6be2521eca69cca2dbb3ffd6bfe8", size = 173274, upload-time = "2026-04-09T22:57:16.198Z" }, +] + +[[package]] +name = "google-auth" +version = "2.52.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "cryptography", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') 
or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, + { name = "pyasn1-modules", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/d4/f8/80d2493cbedece1c623dc3e3cb1883300871af0dcdae254409522985ac23/google_auth-2.52.0.tar.gz", hash = "sha256:01f30e1a9e3638698d89464f5e603ce29d18e1c0e63ec31ac570aba4e164aaf5", size = 335027, upload-time = "2026-05-07T19:45:24.033Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ee/fc/2cdc74252746f547f81ff3f02d4d4234a3f411b5de5b61af97e633a060b9/google_auth-2.52.0-py3-none-any.whl", hash = "sha256:aee92803ba0ff93a70a3b8a35c7b4797837751cd6380b63ff38372b98f3ed627", size = 245614, upload-time = "2026-05-07T19:45:21.914Z" }, +] + +[[package]] +name = "google-cloud-core" +version = "2.6.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "google-api-core", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, + { name = "google-auth", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 
'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/a8/dd/1eef226e470369b26824a505c34482c0b493bc35fe8e0c6b003b5feca21a/google_cloud_core-2.6.0.tar.gz", hash = "sha256:e76149739f90fac1fc6757c09f47eaccb3145b54adbd7759b0f7c4b235f46c83", size = 36001, upload-time = "2026-05-07T08:04:04.124Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/84/4a/98da8930ab109c73d9a5d13782a9ebb81ea8c111f6d534a567b71d23e52b/google_cloud_core-2.6.0-py3-none-any.whl", hash = "sha256:6d63ac8e5eca6d9e4319d0a1e2265fadcd7f1049904378caecfa01cf52dd869e", size = 29390, upload-time = "2026-05-07T08:02:34.672Z" }, +] + +[[package]] +name = "google-cloud-storage" +version = "3.10.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "google-api-core", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, + { name = "google-auth", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, + { name = "google-cloud-core", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 
'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, + { name = "google-crc32c", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, + { name = "google-resumable-media", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, + { name = "requests", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/4c/47/205eb8e9a1739b5345843e5a425775cbdc472cc38e7eda082ba5b8d02450/google_cloud_storage-3.10.1.tar.gz", hash = "sha256:97db9aa4460727982040edd2bd13ff3d5e2260b5331ad22895802da1fc2a5286", size = 17309950, upload-time = "2026-03-23T09:35:23.409Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ad/ff/ca9ab2417fa913d75aae38bf40bf856bb2749a604b2e0f701b37cfcd23cc/google_cloud_storage-3.10.1-py3-none-any.whl", hash = "sha256:a72f656759b7b99bda700f901adcb3425a828d4a29f911bc26b3ea79c5b1217f", size = 
324453, upload-time = "2026-03-23T09:35:21.368Z" }, +] + +[[package]] +name = "google-crc32c" +version = "1.8.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/03/41/4b9c02f99e4c5fb477122cd5437403b552873f014616ac1d19ac8221a58d/google_crc32c-1.8.0.tar.gz", hash = "sha256:a428e25fb7691024de47fecfbff7ff957214da51eddded0da0ae0e0f03a2cf79", size = 14192, upload-time = "2025-12-16T00:35:25.142Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ce/a9/a780cc66f86335a6019f557a8aaca8fbb970728f0efd2430d15ff1beae0e/google_crc32c-1.8.0-cp312-cp312-manylinux1_x86_64.manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:14f87e04d613dfa218d6135e81b78272c3b904e2a7053b841481b38a7d901411", size = 33364, upload-time = "2025-12-16T00:40:22.96Z" }, + { url = "https://files.pythonhosted.org/packages/21/3f/3457ea803db0198c9aaca2dd373750972ce28a26f00544b6b85088811939/google_crc32c-1.8.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:cb5c869c2923d56cb0c8e6bcdd73c009c36ae39b652dbe46a05eb4ef0ad01454", size = 33740, upload-time = "2025-12-16T00:40:23.96Z" }, +] + +[[package]] +name = "google-resumable-media" +version = "2.9.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "google-crc32c", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/00/4b/0b235beccc310d0a48adbc7246b719d173cca6c88c572dfa4b090e39143c/google_resumable_media-2.9.0.tar.gz", hash = "sha256:f7cfb224846a9dd444d125115dfbe8ef02a2b893e78f087762fe716a255a734b", 
size = 2164534, upload-time = "2026-05-07T08:04:44.236Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/07/73/3518e63deb1667c5409a4579e28daf5e84479a87a72c547e0487f7883dcd/google_resumable_media-2.9.0-py3-none-any.whl", hash = "sha256:c8901e88e389af8bed64d9696c74d8bad961865eb2236e13e0bfca9bb0a65ca3", size = 81507, upload-time = "2026-05-07T08:03:23.809Z" }, +] + [[package]] name = "googleapis-common-protos" version = "1.75.0" @@ -1816,6 +1970,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/89/a5/33b49ba7bea7c41bb37f74ec0f8beea0831e052330196633fe2c77516ea6/huggingface_hub-1.14.0-py3-none-any.whl", hash = "sha256:efe075535c62e130b30e836b138e13785f6f043d1f0539e0a39aa411a99e90b8", size = 661479, upload-time = "2026-05-06T14:14:32.029Z" }, ] +[[package]] +name = "humanize" +version = "4.15.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/ba/66/a3921783d54be8a6870ac4ccffcd15c4dc0dd7fcce51c6d63b8c63935276/humanize-4.15.0.tar.gz", hash = "sha256:1dd098483eb1c7ee8e32eb2e99ad1910baefa4b75c3aff3a82f4d78688993b10", size = 83599, upload-time = "2025-12-20T20:16:13.19Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c5/7b/bca5613a0c3b542420cf92bd5e5fb8ebd5435ce1011a091f66bb7693285e/humanize-4.15.0-py3-none-any.whl", hash = "sha256:b1186eb9f5a9749cd9cb8565aee77919dd7c8d076161cf44d70e59e3301e1769", size = 132203, upload-time = "2025-12-20T20:16:11.67Z" }, +] + [[package]] name = "hyperframe" version = "6.1.0" @@ -1994,6 +2157,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/56/6d/0d9848617b9f753b87f214f1c682592f7ca42de085f564352f10f0843026/ipywidgets-8.1.8-py3-none-any.whl", hash = "sha256:ecaca67aed704a338f88f67b1181b58f821ab5dc89c1f0f5ef99db43c1c2921e", size = 139808, upload-time = "2025-11-01T21:18:10.956Z" }, ] +[[package]] +name = "isodate" +version = "0.7.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = 
"https://files.pythonhosted.org/packages/54/4d/e940025e2ce31a8ce1202635910747e5a87cc3a6a6bb2d00973375014749/isodate-0.7.2.tar.gz", hash = "sha256:4cd1aa0f43ca76f4a6c6c0292a85f40b35ec2e43e315b59f06e6d32171a953e6", size = 29705, upload-time = "2024-10-08T23:04:11.5Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/15/aa/0aca39a37d3c7eb941ba736ede56d689e7be91cab5d9ca846bde3999eba6/isodate-0.7.2-py3-none-any.whl", hash = "sha256:28009937d8031054830160fce6d409ed342816b543597cece116d966c6d99e15", size = 22320, upload-time = "2024-10-08T23:04:09.501Z" }, +] + [[package]] name = "jaraco-classes" version = "3.4.0" @@ -2798,6 +2970,21 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/61/26/c7aea197f1719f31d0dd686eb4475982fe9efd7668ce259cb52b62c676b6/model_hosting_container_standards-0.1.15-py3-none-any.whl", hash = "sha256:849e08c4732203ee861c8c24966b4e916ea4420fa324b430f7f74a1e1fe8811a", size = 125418, upload-time = "2026-05-05T18:22:27.819Z" }, ] +[[package]] +name = "modelexpress" +version = "0.3.0" +source = { git = "https://github.com/ai-dynamo/modelexpress.git?subdirectory=modelexpress_client%2Fpython&rev=b0c94ed#b0c94ed61c65d2c2355a18508977941b9946d8b5" } +dependencies = [ + { name = "grpcio", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, + { name = "huggingface-hub", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 
'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, + { name = "nixl", extra = ["cu12"], marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, + { name = "numpy", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, + { name = "protobuf", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, + { name = "pydantic", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, + { name = "runai-model-streamer", extra = ["azure", "gcs", "s3"], marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine 
== 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, + { name = "torch", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, +] + [[package]] name = "more-itertools" version = "11.0.2" @@ -2816,6 +3003,32 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/43/e3/7d92a15f894aa0c9c4b49b8ee9ac9850d6e63b03c9c32c0367a13ae62209/mpmath-1.3.0-py3-none-any.whl", hash = "sha256:a0b2b9fe80bbcd81a6647ff13108738cfb482d481d826cc0e02f5b35e5c88d2c", size = 536198, upload-time = "2023-03-07T16:47:09.197Z" }, ] +[[package]] +name = "msal" +version = "1.36.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "cryptography", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, + { name = "pyjwt", extra = ["crypto"], marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, + { 
name = "requests", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/de/cb/b02b0f748ac668922364ccb3c3bff5b71628a05f5adfec2ba2a5c3031483/msal-1.36.0.tar.gz", hash = "sha256:3f6a4af2b036b476a4215111c4297b4e6e236ed186cd804faefba23e4990978b", size = 174217, upload-time = "2026-04-09T10:20:33.525Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/2a/d3/414d1f0a5f6f4fe5313c2b002c54e78a3332970feb3f5fed14237aa17064/msal-1.36.0-py3-none-any.whl", hash = "sha256:36ecac30e2ff4322d956029aabce3c82301c29f0acb1ad89b94edcabb0e58ec4", size = 121547, upload-time = "2026-04-09T10:20:32.336Z" }, +] + +[[package]] +name = "msal-extensions" +version = "1.3.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "msal", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/01/99/5d239b6156eddf761a636bded1118414d161bd6b7b37a9335549ed159396/msal_extensions-1.3.1.tar.gz", hash = "sha256:c5b0fd10f65ef62b5f1d62f4251d51cbcaf003fcedae8c91b040a488614be1a4", size = 23315, upload-time = "2025-03-14T23:51:03.902Z" } +wheels = [ + { url = 
"https://files.pythonhosted.org/packages/5e/75/bd9b7bb966668920f06b200e84454c8f3566b102183bc55c5473d96cb2b9/msal_extensions-1.3.1-py3-none-any.whl", hash = "sha256:96d3de4d034504e969ac5e85bae8106c8373b5c6568e4c8fa7af2eca9dbe6bca", size = 20583, upload-time = "2025-03-14T23:51:03.016Z" }, +] + [[package]] name = "msgpack" version = "1.1.2" @@ -2920,6 +3133,12 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/bb/a4/10e80e623790d3fb070e966b36bced009419d483c92ae0df6645230f606f/nixl-0.10.1-py3-none-any.whl", hash = "sha256:616465673dae5180d296525a03237af4cd5f2c00c3228d185bc06dbe621509b7", size = 6680, upload-time = "2026-03-03T19:55:44.848Z" }, ] +[package.optional-dependencies] +cu12 = [ + { name = "nixl-cu12", version = "0.10.1", source = { registry = "https://pypi.org/simple" }, marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine != 'aarch64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, + { name = "nixl-cu12", version = "0.10.1", source = { url = "https://github.com/PrimeIntellect-ai/prime-rl/releases/download/v0.5.0/nixl_cu12-0.10.1-cp312-cp312-linux_x86_64.whl" }, marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, +] + [[package]] name = "nixl-cu12" version = "0.10.1" @@ -3930,6 +4149,7 @@ all = [ { name = "flash-attn", version = "2.8.3+cu128torch2.11", source = { url = "https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.9.4/flash_attn-2.8.3+cu128torch2.11-cp312-cp312-linux_x86_64.whl" }, marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine != 'x86_64' and 
extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, { name = "flash-attn-3", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, { name = "flash-attn-4", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, + { name = "modelexpress", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, { name = "nixl", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, { name = "nixl-cu12", version = "0.10.1", source = { url = 
"https://github.com/PrimeIntellect-ai/prime-rl/releases/download/v0.5.0/nixl_cu12-0.10.1-cp312-cp312-linux_x86_64.whl" }, marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, { name = "quack-kernels", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, @@ -3938,6 +4158,7 @@ all = [ disagg = [ { name = "deep-ep", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, { name = "deep-gemm", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, + { name = "modelexpress", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, { name = "nixl", marker = "(platform_machine != 'aarch64' 
and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, { name = "nixl-cu12", version = "0.10.1", source = { url = "https://github.com/PrimeIntellect-ai/prime-rl/releases/download/v0.5.0/nixl_cu12-0.10.1-cp312-cp312-linux_x86_64.whl" }, marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, { name = "vllm-router", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, @@ -4032,6 +4253,7 @@ requires-dist = [ { name = "mini-swe-agent-plus", marker = "extra == 'envs'", editable = "deps/research-environments/environments/mini_swe_agent_plus" }, { name = "mini-swe-agent-plus-rlm", marker = "extra == 'envs'", editable = "deps/research-environments/environments/mini_swe_agent_plus_rlm" }, { name = "mmlu-pro", marker = "extra == 'envs'", editable = "deps/research-environments/environments/mmlu_pro" }, + { name = "modelexpress", marker = "extra == 'disagg'", git = "https://github.com/ai-dynamo/modelexpress.git?subdirectory=modelexpress_client%2Fpython&rev=b0c94ed" }, { name = "nixl", marker = "extra == 'disagg'" }, { name = "nixl-cu12", marker = "platform_machine == 'x86_64' and extra == 'disagg'", url = 
"https://github.com/PrimeIntellect-ai/prime-rl/releases/download/v0.5.0/nixl_cu12-0.10.1-cp312-cp312-linux_x86_64.whl" }, { name = "numpy", specifier = ">=2.2.6" }, @@ -4188,6 +4410,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/3a/ed/1cdcab6ba3d6ab7feca11fc14f0eeea80755bb53ef4e892079f31b10a25f/propcache-0.5.2-py3-none-any.whl", hash = "sha256:be1ddfcbb376e3de5d2e2db1d58d6d67463e6b4f9f040c000de8e300295465fe", size = 14036, upload-time = "2026-05-08T21:02:10.673Z" }, ] +[[package]] +name = "proto-plus" +version = "1.28.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "protobuf", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/c9/56/e647b0c675392d2da368da7b6f158f7368b18542fd6f7d7400a2f39de000/proto_plus-1.28.0.tar.gz", hash = "sha256:38e5696342835b08fc116f30a25665b29531cda9d5d5643e9b81fc312385abd9", size = 57221, upload-time = "2026-05-07T08:04:50.811Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/7c/20/b122d4626976acb81132036d2ad1bb35a1a8775fceb837ec30964622516a/proto_plus-1.28.0-py3-none-any.whl", hash = "sha256:a630604310899e73c59ec302e5765c058d412b2f090b9c79c8822589f14955b8", size = 50410, upload-time = "2026-05-07T08:03:31.962Z" }, +] + [[package]] name = "protobuf" version = "6.33.6" @@ -4291,6 +4525,27 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/36/7a/82c363caa145fff88fb475da50d3bf52bb024f61917be5424c3392eaf878/pyarrow-24.0.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:25ea65d868eb04015cd18e6df2fbe98f07e5bda2abefabcb88fce39a947716f6", size = 51929490, 
upload-time = "2026-04-21T10:47:55.981Z" }, ] +[[package]] +name = "pyasn1" +version = "0.6.3" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/5c/5f/6583902b6f79b399c9c40674ac384fd9cd77805f9e6205075f828ef11fb2/pyasn1-0.6.3.tar.gz", hash = "sha256:697a8ecd6d98891189184ca1fa05d1bb00e2f84b5977c481452050549c8a72cf", size = 148685, upload-time = "2026-03-17T01:06:53.382Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/5d/a0/7d793dce3fa811fe047d6ae2431c672364b462850c6235ae306c0efd025f/pyasn1-0.6.3-py3-none-any.whl", hash = "sha256:a80184d120f0864a52a073acc6fc642847d0be408e7c7252f31390c0f4eadcde", size = 83997, upload-time = "2026-03-17T01:06:52.036Z" }, +] + +[[package]] +name = "pyasn1-modules" +version = "0.4.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pyasn1", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/e9/e6/78ebbb10a8c8e4b61a59249394a4a594c1a7af95593dc933a349c8d00964/pyasn1_modules-0.4.2.tar.gz", hash = "sha256:677091de870a80aae844b1ca6134f54652fa2c8c5a52aa396440ac3106e941e6", size = 307892, upload-time = "2025-03-28T02:41:22.17Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/47/8d/d529b5d697919ba8c11ad626e835d4039be708a35b0d22de83a269a6682c/pyasn1_modules-0.4.2-py3-none-any.whl", hash = "sha256:29253a9207ce32b64c3ac6600edc75368f98473906e8fd1043bd6b5b1de2c14a", size = 181259, upload-time = "2025-03-28T02:41:19.028Z" }, +] + [[package]] name = "pybase64" version = "1.4.3" @@ -4926,6 +5181,81 @@ wheels = [ { url = 
"https://files.pythonhosted.org/packages/72/7a/a9ba7f98c7a575978698f4230c5e8cc54bbc761af34f560818f933dafa0c/ruff-0.15.12-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:9cae0f92bd5700d1213188b31cd3bdd2b315361296d10b96b8e2337d3d11f53e", size = 11447765, upload-time = "2026-04-24T18:17:09.755Z" }, ] +[[package]] +name = "runai-model-streamer" +version = "0.15.9" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "humanize", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, + { name = "numpy", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, + { name = "torch", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/cf/af/bf5a01398c7dec217427eb148772d3b5df430f23f57c0fcc5e94ef11a0f3/runai_model_streamer-0.15.9-py3-none-manylinux2014_aarch64.whl", hash = "sha256:0a189e8b844fe6b8fb108ca27b311add5b76c55d59e1d7822bf94f5c2d81388d", size = 608124, upload-time = 
"2026-04-30T17:57:42.992Z" }, + { url = "https://files.pythonhosted.org/packages/95/67/a36882fed852f209e59e4e58e87ef15063d1cf828941d0510cf5b37b9b40/runai_model_streamer-0.15.9-py3-none-manylinux2014_x86_64.whl", hash = "sha256:aaa02c5063abb326e0a5f7969e8ee4d850eaf5752e50d7e5fee9d38c6b98d10b", size = 608335, upload-time = "2026-04-30T17:57:44.68Z" }, +] + +[package.optional-dependencies] +azure = [ + { name = "runai-model-streamer-azure", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, +] +gcs = [ + { name = "runai-model-streamer-gcs", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, +] +s3 = [ + { name = "runai-model-streamer-s3", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, +] + +[[package]] +name = "runai-model-streamer-azure" +version = "0.15.9" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "azure-identity", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 
'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, + { name = "azure-storage-blob", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/fc/4a/15d2477277f4a3c48dccadf96c195af0f48b98b2626e3d1f3a459ac2c5e0/runai_model_streamer_azure-0.15.9-py3-none-manylinux2014_aarch64.whl", hash = "sha256:6e90ba5259b1571f08b11805e22587336f2891e9c4b36dba42a054d56d13d2cc", size = 5851347, upload-time = "2026-04-30T17:58:03.398Z" }, + { url = "https://files.pythonhosted.org/packages/a7/fc/a989edd63e8abde21c1dc1193cbcb237bb5b03d4498f96210b9c632fdd86/runai_model_streamer_azure-0.15.9-py3-none-manylinux2014_x86_64.whl", hash = "sha256:27bee4c08d9be24cb223c9c7338f55c730c41a70508665c095d300d33c9f5b72", size = 5570107, upload-time = "2026-04-30T17:58:05.175Z" }, +] + +[[package]] +name = "runai-model-streamer-gcs" +version = "0.15.9" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "google-auth", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, + { name = 
"google-cloud-storage", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/df/6c/d8988071fe416387d3e090fea49c15f55f97766766df5a8e1c152a76cfce/runai_model_streamer_gcs-0.15.9-py3-none-manylinux2014_aarch64.whl", hash = "sha256:b48ab124dd0400d1ee795ff48ce022551a14d44c655b5a09c25d0ea62ba4f8fd", size = 23226423, upload-time = "2026-04-30T17:57:55.707Z" }, + { url = "https://files.pythonhosted.org/packages/80/4c/63d3ddc67a6c52c5e78f1c7f0db17600a245adba8e0f4dfd6986bc17929b/runai_model_streamer_gcs-0.15.9-py3-none-manylinux2014_x86_64.whl", hash = "sha256:15665c4c8cb43e4fcee344bf0171dc4fb33c0fa643dc873a4487d4fe78be7f08", size = 23170930, upload-time = "2026-04-30T17:57:58.516Z" }, +] + +[[package]] +name = "runai-model-streamer-s3" +version = "0.15.9" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "boto3", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/3b/8a/274c765434851189ed25ebef2e32be829c221ddc5ff8985a1fb24cd364d1/runai_model_streamer_s3-0.15.9-py3-none-manylinux2014_aarch64.whl", hash = "sha256:4303d3331b7cee46319d0bbf5e0011c6b36938705cc94c3d3f35039cf4f07ce3", size = 6180151, upload-time = "2026-04-30T17:57:48.882Z" }, + { url 
= "https://files.pythonhosted.org/packages/a8/57/e2d7b5ba3ef70f55828a26faf3b3626407d46033c03bffd7eb45d247face/runai_model_streamer_s3-0.15.9-py3-none-manylinux2014_x86_64.whl", hash = "sha256:80fc380af719207ac765d3cd97f3391dbec1e96fa63cc429d6a00e932be75ef2", size = 5921555, upload-time = "2026-04-30T17:57:50.76Z" }, +] + +[[package]] +name = "s3transfer" +version = "0.17.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "botocore", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/9b/ec/7c692cde9125b77e84b307354d4fb705f98b8ccad59a036d5957ca75bfc3/s3transfer-0.17.0.tar.gz", hash = "sha256:9edeb6d1c3c2f89d6050348548834ad8289610d886e5bf7b7207728bd43ce33a", size = 155337, upload-time = "2026-04-29T22:07:36.33Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/87/72/c6c32d2b657fa3dad1de340254e14390b1e334ce38268b7ad51abda3c8c2/s3transfer-0.17.0-py3-none-any.whl", hash = "sha256:ce3801712acf4ad3e89fb9990df97b4972e93f4b3b0004d214be5bce12814c20", size = 86811, upload-time = "2026-04-29T22:07:34.966Z" }, +] + [[package]] name = "safetensors" version = "0.7.0"