Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
328011e
Something
S1ro1 Apr 30, 2026
724cf75
Feat: config scaffold
S1ro1 May 2, 2026
cdf54ec
Feat: MX dependencies
S1ro1 May 2, 2026
bd1d27c
Feat: sbatch wiring
S1ro1 May 2, 2026
2d8d4ba
Feat: wire conversion spec
S1ro1 May 2, 2026
cdf86ab
Feat: allocate slots
S1ro1 May 2, 2026
4c17566
Feat: rendezvous
S1ro1 May 2, 2026
321dc16
Feat: NIXL agent
S1ro1 May 2, 2026
897917f
Feat: trainer publisher
S1ro1 May 2, 2026
c299027
Feat: cleanup + style
S1ro1 May 2, 2026
6186726
chore: drop unrelated sparse_mla and sonic_ep work from this branch
S1ro1 May 2, 2026
8fe0549
Feat: cleanup + sharding
S1ro1 May 2, 2026
b8e266c
Nit: style
S1ro1 May 2, 2026
4d50aab
Feat: rendezvous metadata
S1ro1 May 2, 2026
c3b9d80
Feat: local metadata exchange
S1ro1 May 2, 2026
b53b8c9
Feat: trainer wiring
S1ro1 May 2, 2026
691a98f
Feat: inference worker
S1ro1 May 2, 2026
51061e8
Feat: orchestrator wiring + cleanup of rendezvous
S1ro1 May 2, 2026
26864da
Fix: ShardedSpec computing twice sharded size
S1ro1 May 2, 2026
39f5379
Fix: stable file race condition
S1ro1 May 3, 2026
4af68d2
Fix: add LD_LIBRARY_PATH to templates
S1ro1 May 3, 2026
977ad07
Fix: move to model.state_dict()
S1ro1 May 3, 2026
97c29c8
Fix: passthrough on embed tokens
S1ro1 May 3, 2026
ca0781e
Fix: cast state dict
S1ro1 May 3, 2026
7589967
Fix: race condition in update_weights_from_path
S1ro1 May 3, 2026
861c909
Fix: rendezvous was 0 rank waiting
S1ro1 May 3, 2026
4467157
Fix: hybrid_lb + MX port
S1ro1 May 3, 2026
50317b9
Fix: revert port changes
S1ro1 May 3, 2026
7a62fcf
Feat: remove NCCL/STABLE file path, instead use NIXL status
S1ro1 May 3, 2026
bca3055
Feat: huge refactor we cooking 🚀🚀🚀
S1ro1 May 5, 2026
bd73c3a
Feat: fix race condition
S1ro1 May 5, 2026
947df0a
fix: stabilize GLM NIXL transfer snapshot
S1ro1 May 22, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -201,4 +201,4 @@ debug_I2_zero_band

outputs/

third_party/
third_party/
90 changes: 90 additions & 0 deletions configs/glm51_math_nixl_dpep16_10node/rl.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
output_dir = "outputs/glm51-math-nixl-dpep16-10node-nixlfix2"
seq_len = 2048
max_steps = 500

[log]
level = "debug"

[wandb]
project = "glm51-math"
name = "glm51-math-nixl-dpep16-10node-nixlfix2"

[weight_broadcast]
type = "nixl_mx"
timeout = 12000

[deployment]
type = "multi_node"
num_train_nodes = 8
num_infer_nodes = 2
gpus_per_node = 8

[slurm]
job_name = "glm51-math-nixl-dpep16"
partition = "cluster"

[trainer]
dist_timeout_seconds = 7200

[trainer.model]
name = "zai-org/GLM-5.1"
impl = "custom"
attn = "flash_attention_2"
fused_lm_head_token_chunk_size = 2048
optim_cpu_offload = true
ep = 8

[trainer.model.compile]

[trainer.model.ac]
freq = 1

[trainer.model.ac_offloading]
max_inflight_activations = 1

[trainer.optim]
type = "sign_sgd"
lr = 1e-6
weight_decay = 0.1

[orchestrator]
batch_size = 64
rollouts_per_example = 16
oversampling_factor = 2
max_off_policy_steps = 8
filters = []

[orchestrator.model]
name = "zai-org/GLM-5.1-FP8"

[orchestrator.renderer]
name = "glm-5.1"

[orchestrator.train.sampling]
max_completion_tokens = 2048

[[orchestrator.train.env]]
id = "math-env"
name = "hendrycks-math"
args = { dataset_name = "PrimeIntellect/Hendrycks-Math", dataset_subset = "default", math_verify_max_workers = 128, math_verify_timeout = 60 }

[orchestrator.buffer]
easy_threshold = 1.0
hard_threshold = 0.0
seed = 42

[inference]
enable_expert_parallel = true
all2all_backend = "allgather_reducescatter"
gpu_memory_utilization = 0.80
enable_eplb = false
use_deep_gemm = true

[inference.model]
name = "zai-org/GLM-5.1-FP8"
tool_call_parser = "glm47"
max_model_len = 2048

[inference.parallel]
tp = 1
dp = 16
31 changes: 31 additions & 0 deletions docker/modelexpress/docker-compose.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
services:
redis:
image: redis:8-alpine
network_mode: host
command: ["redis-server", "--port", "29502"]
healthcheck:
test: ["CMD", "redis-cli", "-p", "29502", "ping"]
interval: 5s
retries: 3
start_period: 5s
restart: "no"

modelexpress-server:
build:
context: https://github.com/ai-dynamo/modelexpress.git#b0c94ed61c65d2c2355a18508977941b9946d8b5
network_mode: host
environment:
- MODEL_EXPRESS_SERVER_PORT=29501
- MODEL_EXPRESS_LOG_LEVEL=info
- MX_METADATA_BACKEND=redis
- REDIS_URL=redis://localhost:29502
- MX_HEARTBEAT_TIMEOUT_SECS=86400
healthcheck:
test: ["CMD", "bash", "-c", "</dev/tcp/localhost/29501"]
interval: 5s
retries: 12
start_period: 5s
depends_on:
redis:
condition: service_healthy
restart: "no"
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ class ModelConfig(BaseModelConfig):


class WeightBroadcastConfig(BaseConfig):
type: Literal["nccl", "filesystem"] = "filesystem"
type: Literal["nccl", "filesystem", "nixl_mx"] = "filesystem"
"""Weight broadcast transport."""


Expand Down
21 changes: 20 additions & 1 deletion packages/prime-rl-configs/src/prime_rl/configs/orchestrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -538,8 +538,27 @@ class NCCLWeightBroadcastConfig(BaseConfig):
"""Total inference GPUs across all servers. Used by ``init_nccl_broadcast`` to compute per-server rank offsets."""


class NIXLMxWeightBroadcastConfig(BaseConfig):
    """Configures NIXL + Model Express weight broadcast.

    One member of the ``WeightBroadcastConfig`` tagged union; it is selected
    when the ``type`` discriminator equals ``"nixl_mx"``.
    """

    # Discriminator tag: the only accepted value is "nixl_mx".
    type: Literal["nixl_mx"] = "nixl_mx"

    host: str = "localhost"
    """Host for the Model Express rendezvous server."""

    # Same port the bundled docker-compose file gives the Model Express server
    # (MODEL_EXPRESS_SERVER_PORT=29501).
    port: int = 29501
    """Port for the Model Express rendezvous server."""

    # A single value covers both the initial rendezvous and every per-step transfer.
    timeout: int = 1200
    """Timeout in seconds for rendezvous and per-step transfers."""

    # ge=1 makes validation reject zero or negative world sizes.
    inference_world_size: int = Field(1, ge=1)
    """Total inference GPUs across all servers."""


WeightBroadcastConfig: TypeAlias = Annotated[
FileSystemWeightBroadcastConfig | NCCLWeightBroadcastConfig, Field(discriminator="type")
FileSystemWeightBroadcastConfig | NCCLWeightBroadcastConfig | NIXLMxWeightBroadcastConfig,
Field(discriminator="type"),
]


Expand Down
53 changes: 40 additions & 13 deletions packages/prime-rl-configs/src/prime_rl/configs/rl.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@
from prime_rl.configs.orchestrator import (
NCCLWeightBroadcastConfig as OrchestratorNCCLWeightBroadcastConfig,
)
from prime_rl.configs.orchestrator import (
NIXLMxWeightBroadcastConfig as OrchestratorNIXLMxWeightBroadcastConfig,
)
from prime_rl.configs.orchestrator import (
OrchestratorConfig,
)
Expand All @@ -31,6 +34,9 @@
from prime_rl.configs.trainer import (
NCCLWeightBroadcastConfig as TrainerNCCLWeightBroadcastConfig,
)
from prime_rl.configs.trainer import (
NIXLMxWeightBroadcastConfig as TrainerNIXLMxWeightBroadcastConfig,
)
from prime_rl.utils.config import BaseConfig, find_package_resource
from prime_rl.utils.validation import (
propagate_shared_fields,
Expand Down Expand Up @@ -113,17 +119,20 @@ class SharedModelConfig(BaseConfig):


class SharedWeightBroadcastConfig(BaseConfig):
type: Literal["nccl", "filesystem"] = "filesystem"
type: Literal["nccl", "filesystem", "nixl_mx"] = "filesystem"
"""Weight broadcast transport."""

host: str = "localhost"
"""Host for weight broadcast rendezvous."""

port: int = 29501
"""Port for NCCL weight broadcast."""
"""Port for weight broadcast rendezvous."""

timeout: int = 1200
"""Timeout in seconds for NCCL weight broadcast."""
"""Timeout in seconds for weight broadcast."""

quantize_in_weight_transfer: bool = False
"""Use kernel-format FP8 quantized NCCL transfer for weight updates. When disabled, uses default HF checkpoint-format transfer."""
"""Use kernel-format FP8 quantized NCCL transfer for weight updates. Only applies when type is ``nccl``."""


class BaseDeploymentConfig(BaseConfig):
Expand Down Expand Up @@ -301,10 +310,11 @@ def validate_no_teacher_in_multinode(self):
@model_validator(mode="after")
def validate_enough_devices_for_nccl(self):
if self.deployment.type == "single_node":
if self.trainer.weight_broadcast.type == "nccl":
if self.trainer.weight_broadcast.type in ("nccl", "nixl_mx"):
if self.deployment.num_train_gpus + self.deployment.num_infer_gpus < 2:
raise ValueError(
"NCCL weight broadcast requires at least 2 GPUs to build the broadcast process group."
f"{self.trainer.weight_broadcast.type} weight broadcast requires at least 2 GPUs "
"to build the broadcast process group."
)
return self

Expand Down Expand Up @@ -359,12 +369,14 @@ def auto_setup_weight_broadcast(self):
self.trainer.weight_broadcast = TrainerNCCLWeightBroadcastConfig(
type=self.weight_broadcast.type,
inference_world_size=inference_world_size,
host=self.weight_broadcast.host,
port=self.weight_broadcast.port,
timeout=self.weight_broadcast.timeout,
quantize_in_weight_transfer=self.weight_broadcast.quantize_in_weight_transfer,
)
self.orchestrator.weight_broadcast = OrchestratorNCCLWeightBroadcastConfig(
type=self.weight_broadcast.type,
host=self.weight_broadcast.host,
port=self.weight_broadcast.port,
timeout=self.weight_broadcast.timeout,
inference_world_size=inference_world_size,
Expand All @@ -373,6 +385,24 @@ def auto_setup_weight_broadcast(self):
elif self.weight_broadcast.type == "filesystem":
self.trainer.weight_broadcast = TrainerFileSystemWeightBroadcastConfig()
self.orchestrator.weight_broadcast = OrchestratorFileSystemWeightBroadcastConfig()
elif self.weight_broadcast.type == "nixl_mx":
inference_world_size = self.inference.parallel.dp * self.inference.parallel.tp if self.inference else 1
inference_model_name = self.inference.model.name if self.inference else self.model.name
self.trainer.weight_broadcast = TrainerNIXLMxWeightBroadcastConfig(
type=self.weight_broadcast.type,
inference_world_size=inference_world_size,
host=self.weight_broadcast.host,
port=self.weight_broadcast.port,
timeout=self.weight_broadcast.timeout,
inference_model_name=inference_model_name,
)
self.orchestrator.weight_broadcast = OrchestratorNIXLMxWeightBroadcastConfig(
type=self.weight_broadcast.type,
inference_world_size=inference_world_size,
host=self.weight_broadcast.host,
port=self.weight_broadcast.port,
timeout=self.weight_broadcast.timeout,
)
if self.inference is not None:
self.inference.weight_broadcast = InferenceWeightBroadcastConfig(type=self.weight_broadcast.type)

Expand Down Expand Up @@ -569,16 +599,15 @@ def auto_setup_deployment(self):
if self.inference.api_server_count == 1 and dp_per_node > 1:
self.inference.api_server_count = dp_per_node

if self.weight_broadcast is not None and self.weight_broadcast.type == "nccl":
if self.weight_broadcast is not None and self.weight_broadcast.type in ("nccl", "nixl_mx"):
# Compute inference_world_size from actual worker count per server:
# each api_server runs tp workers that participate in collective_rpc.
api_server_count = self.inference.api_server_count if self.inference else 1
tp = self.inference.parallel.tp if self.inference else 1
total_infer_workers = self.deployment.total_infer_nodes * api_server_count * tp
assert self.trainer.weight_broadcast.type == "nccl"
self.trainer.weight_broadcast.host = "0.0.0.0"
if self.trainer.weight_broadcast.type == "nccl":
self.trainer.weight_broadcast.host = "0.0.0.0"
self.trainer.weight_broadcast.inference_world_size = total_infer_workers
assert self.orchestrator.weight_broadcast.type == "nccl"
self.orchestrator.weight_broadcast.inference_world_size = total_infer_workers

return self
Expand All @@ -601,10 +630,8 @@ def auto_setup_disaggregated_inference(self):
)

total_infer_gpus = self.deployment.total_infer_nodes * self.deployment.gpus_per_node
if self.weight_broadcast is not None and self.weight_broadcast.type == "nccl":
assert self.trainer.weight_broadcast.type == "nccl"
if self.weight_broadcast is not None and self.weight_broadcast.type in ("nccl", "nixl_mx"):
self.trainer.weight_broadcast.inference_world_size = total_infer_gpus
assert self.orchestrator.weight_broadcast.type == "nccl"
self.orchestrator.weight_broadcast.inference_world_size = total_infer_gpus

return self
Expand Down
25 changes: 24 additions & 1 deletion packages/prime-rl-configs/src/prime_rl/configs/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -474,8 +474,31 @@ class NCCLWeightBroadcastConfig(BaseWeightBroadcastConfig):
"""Use kernel-format FP8 quantized NCCL transfer for weight updates. When disabled, uses default HF checkpoint-format transfer."""


class NIXLMxWeightBroadcastConfig(BaseWeightBroadcastConfig):
    """Configures NIXL (UCX/RDMA) weight transfer with Model Express rendezvous.

    One member of the trainer ``WeightBroadcastConfig`` tagged union; it is
    selected when the ``type`` discriminator equals ``"nixl_mx"``.
    """

    # Discriminator tag: the only accepted value is "nixl_mx".
    type: Literal["nixl_mx"] = "nixl_mx"

    host: str = "localhost"
    """Host for the Model Express rendezvous server."""

    port: int = 29501
    """Port for the Model Express rendezvous server."""

    # A single value covers both the initial rendezvous and every per-step transfer.
    timeout: int = 1200
    """Timeout in seconds for rendezvous and per-step transfers."""

    # TODO: Should not be configurable, but auto-inferred
    # ge=1 mirrors the orchestrator-side NIXL config so both ends reject a
    # zero or negative world size at validation time.
    inference_world_size: int = Field(1, ge=1)
    """Number of GPUs used for inference. Must be at least 1."""

    inference_model_name: str = ""
    """HF model name or local path of the inference target."""


WeightBroadcastConfig: TypeAlias = Annotated[
FileSystemWeightBroadcastConfig | NCCLWeightBroadcastConfig, Field(discriminator="type")
FileSystemWeightBroadcastConfig | NCCLWeightBroadcastConfig | NIXLMxWeightBroadcastConfig,
Field(discriminator="type"),
]


Expand Down
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ disagg = [
"nixl",
"nixl-cu12 ; platform_machine == 'x86_64'",
"vllm-router ; platform_machine == 'x86_64'",
"modelexpress",
]
gpt-oss = [
"kernels",
Expand Down Expand Up @@ -172,6 +173,7 @@ override-dependencies = [
"transformers>=5.1.0.dev0",
"torch>=2.9.0",
"openenv-core",
"protobuf>=6.31.1",
]

[tool.uv.exclude-newer-package]
Expand Down Expand Up @@ -238,6 +240,7 @@ vllm = [
deep-ep = { url = "https://github.com/PrimeIntellect-ai/prime-rl/releases/download/v0.5.0/deep_ep-1.2.1+29d31c0-cp312-cp312-linux_x86_64.whl" }
deep-gemm = { url = "https://github.com/PrimeIntellect-ai/prime-rl/releases/download/v0.5.0/deep_gemm-2.5.0+891d57b-cp312-cp312-linux_x86_64.whl" }
nixl-cu12 = { url = "https://github.com/PrimeIntellect-ai/prime-rl/releases/download/v0.5.0/nixl_cu12-0.10.1-cp312-cp312-linux_x86_64.whl" }
modelexpress = { git = "https://github.com/ai-dynamo/modelexpress.git", subdirectory = "modelexpress_client/python", rev = "b0c94ed" }
flash-linear-attention = { git = "https://github.com/fla-org/flash-linear-attention" }
flash_attn_3 = { index = "pytorch-cu128-test" }

Expand Down
22 changes: 22 additions & 0 deletions src/prime_rl/_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,28 @@
Each shim documents the upstream issue and removal condition.
"""

# ---------------------------------------------------------------------------
# tilelang ships a stub libcudart that proxies to the real CUDA runtime via
# dlsym(RTLD_DEFAULT, ...). If the stub's symbols are the first ones found
# (because nothing has loaded the real libcudart globally yet) its self-check
# fails and the stub aborts — hit the moment any code calls into the
# classic-cudaMalloc MemPool (used for NIXL-registered slot buffers).
#
# Preloading the real library with RTLD_GLOBAL at this very early point —
# before transformers/torch/tilelang are pulled into the process — makes
# dlsym find the real symbols first.
#
# Wrapped in try/except because CDLL can fail on machines without a real
# CUDA runtime (e.g. CI containers).
# ---------------------------------------------------------------------------
import ctypes as _ctypes

try:
    # RTLD_GLOBAL makes the loaded library's symbols available to libraries
    # that are loaded afterwards, which is what the later dlsym lookup needs.
    _ctypes.CDLL("libcudart.so", mode=_ctypes.RTLD_GLOBAL)
except OSError:
    # Best effort: the library could not be loaded on this machine, so the
    # preload is skipped and import continues unchanged.
    pass


# ---------------------------------------------------------------------------
# ring_flash_attn + transformers >= 5.4
#
Expand Down
Loading