Open
Description
How is this issue impacting you?
Data corruption
Share Your Debug Logs
We are seeing silent rank-order corruption in all_gather in the following two scenarios:
Scenario 1:
- Certain permutations of CUDA_VISIBLE_DEVICES (e.g., 1,2,3,4,5,6,7,0 or 3,4,5,6,7,0,1,2)
- NCCL_P2P_DISABLE=1
- NCCL version is 2.26.2 and PyTorch version is 2.7
- Using 8 GPUs on a node.
Scenario 2:
- Certain permutations of CUDA_VISIBLE_DEVICES (e.g., 1,2,3,4,5,6,7,0 or 3,4,5,6,7,0,1,2)
- NCCL_ALGO=NVLS
- NCCL version is 2.26.2 and PyTorch version is 2.7
- Using 8 GPUs on a node.
I suspect both scenarios share the same underlying cause, and that setting NCCL_P2P_DISABLE=1 just makes the issue easier to reproduce under a small load.
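For context on why the permutation matters: CUDA_VISIBLE_DEVICES renumbers devices, so logical `cuda:i` inside a process refers to the i-th entry of the list. The sketch below shows which physical GPU each rank ends up on when the device is chosen as `rank % device_count`, as the repro script does. The helper names are hypothetical, and it assumes the list holds plain integer indices rather than GPU UUIDs.

```python
def visible_devices(cuda_visible_devices):
    """Parse a CUDA_VISIBLE_DEVICES string into a list of physical GPU indices."""
    return [int(tok) for tok in cuda_visible_devices.split(",")]


def rank_to_physical_gpu(cuda_visible_devices, world_size):
    """Map each rank to a physical GPU, assuming device = rank % device_count."""
    visible = visible_devices(cuda_visible_devices)
    return {rank: visible[rank % len(visible)] for rank in range(world_size)}


if __name__ == "__main__":
    for rank, gpu in rank_to_physical_gpu("1,2,3,4,5,6,7,0", 8).items():
        print(f"rank {rank} -> cuda:{rank} -> physical GPU {gpu}")
```

With the rotated list, rank 0 runs on physical GPU 1 and rank 7 on physical GPU 0, so rank order and physical GPU order no longer agree.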
Steps to Reproduce the Issue
Below is a minimal repro script nccl_debug.py:
# Test PyTorch NCCL
import torch
import torch.distributed as dist
import os
# Initialize distributed to get rank
dist.init_process_group(backend='nccl')
rank = dist.get_rank()
world_size = dist.get_world_size()
# Collect information silently
cuda_visible_devices = os.environ.get('CUDA_VISIBLE_DEVICES', 'NOT SET')
nccl_p2p_disable = os.environ.get('NCCL_P2P_DISABLE', 'NOT SET')
# Collect NCCL environment variables
nccl_env_vars = []
for key, value in sorted(os.environ.items()):
    if 'NCCL' in key:
        nccl_env_vars.append(f"{key}: {value}")
# Collect PyTorch and CUDA information
pytorch_version = torch.__version__
cuda_available = torch.cuda.is_available()
device_count = torch.cuda.device_count()
nccl_version = torch.cuda.nccl.version() if cuda_available else "N/A"
device_names = []
if cuda_available:
    for i in range(device_count):
        device_names.append(f"Device {i}: {torch.cuda.get_device_name(i)}")
# Set up device
local_rank = rank % device_count
dev = torch.device("cuda", local_rank)
device_name = torch.cuda.get_device_name(dev)
torch.cuda.set_device(dev)
current_device = torch.cuda.current_device()
# Create tensors and perform all_gather
x = torch.full((4,), float(rank), device=dev)
out = [torch.empty_like(x) for _ in range(world_size)]
dist.barrier()
dist.all_gather(out, x)
dist.barrier()
# Verify results
results = []
all_passed = True
for i, t in enumerate(out):
    expected = float(i)
    actual = t[0].item()
    passed = torch.all(t == float(i))
    results.append(f" From rank {i}: {t.tolist()} (expected {expected}, got {actual}) - {'✓' if passed else '✗'}")
    if not passed:
        all_passed = False
# Build entire summary as a single string to print atomically
summary_lines = []
summary_lines.append(f"\n{'='*80}")
summary_lines.append(f"[RANK {rank}/{world_size}] NCCL DEBUG SUMMARY")
summary_lines.append(f"{'='*80}")
summary_lines.append(f"[ENVIRONMENT]")
summary_lines.append(f" CUDA_VISIBLE_DEVICES: {cuda_visible_devices}")
summary_lines.append(f" NCCL_P2P_DISABLE: {nccl_p2p_disable}")
if nccl_env_vars:
    summary_lines.append(f" NCCL environment variables:")
    for nccl_var in nccl_env_vars:
        summary_lines.append(f" {nccl_var}")
summary_lines.append(f"[PYTORCH & CUDA]")
summary_lines.append(f" PyTorch version: {pytorch_version}")
summary_lines.append(f" CUDA available: {cuda_available}")
summary_lines.append(f" CUDA device count: {device_count}")
summary_lines.append(f" NCCL version: {nccl_version}")
for device_name_line in device_names:
    summary_lines.append(f" {device_name_line}")
summary_lines.append(f"[DEVICE ASSIGNMENT]")
summary_lines.append(f" Local rank: {local_rank}")
summary_lines.append(f" Assigned device: {dev}")
summary_lines.append(f" Device name: {device_name}")
summary_lines.append(f" Current device: {current_device}")
summary_lines.append(f"[NCCL ALL_GATHER RESULTS]")
summary_lines.append(f" Input tensor: {x.tolist()}")
for result_line in results:
    summary_lines.append(result_line)
summary_lines.append(f"[STATUS]")
if all_passed:
    summary_lines.append(f" ✓✓✓ SUCCESS: All tests passed on rank {rank}")
else:
    summary_lines.append(f" ✗✗✗ FAILURE: Some tests failed on rank {rank}")
summary_lines.append(f"{'='*80}\n")
# Synchronize before printing
dist.barrier()
# Print entire summary as one atomic operation
print("\n".join(summary_lines))
# Synchronize after printing
dist.barrier()
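The script above calls init_process_group() without a device and derives the device from `rank % device_count`. One variant that may be worth trying, to rule out the rank-to-GPU mapping as a factor, is to bind the device before initialization using the LOCAL_RANK variable that torchrun sets and to pass it as `device_id`. This is a sketch only: the function names are hypothetical, and it is an assumption (not a finding) that this changes the observed behavior.

```python
import os


def local_rank_from_env(env):
    """Read the per-node rank that torchrun exports as LOCAL_RANK."""
    return int(env["LOCAL_RANK"])


def init_with_bound_device():
    """Bind this process to its GPU first, then create the NCCL process group."""
    # Imported lazily so local_rank_from_env() is usable on machines without a GPU.
    import torch
    import torch.distributed as dist

    dev = torch.device("cuda", local_rank_from_env(os.environ))
    torch.cuda.set_device(dev)
    # device_id tells the process group which device this rank uses up front.
    dist.init_process_group(backend="nccl", device_id=dev)
    return dev
```

In the repro script, `init_with_bound_device()` would take the place of the bare `dist.init_process_group(backend='nccl')` call and the later device setup.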
To reproduce the issue, run:
Scenario 1:
export CUDA_VISIBLE_DEVICES="1,2,3,4,5,6,7,0"
export NCCL_P2P_DISABLE=1
torchrun --nproc-per-node 8 --standalone nccl_debug.py
Scenario 2:
export CUDA_VISIBLE_DEVICES="1,2,3,4,5,6,7,0"
export NCCL_ALGO=NVLS
torchrun --nproc-per-node 8 --standalone nccl_debug.py
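Both example orderings in this report are rotations of 0..7 (by 1 and by 3). To check which orderings trigger the problem, the commands above could be generated for every rotation. The sketch below only builds the command lines and does not run them. The helper names are hypothetical, and restricting the sweep to rotations is an assumption; other permutations may behave differently.

```python
import shlex


def rotations(num_gpus):
    """Return every rotation of 0..num_gpus-1 as a CUDA_VISIBLE_DEVICES string."""
    ids = list(range(num_gpus))
    return [",".join(str(g) for g in ids[k:] + ids[:k]) for k in range(num_gpus)]


def repro_command(visible, extra_env):
    """Build a one-line shell command that runs the repro with the given environment."""
    env = {"CUDA_VISIBLE_DEVICES": visible, **extra_env}
    prefix = " ".join(f"{key}={shlex.quote(value)}" for key, value in env.items())
    return f"{prefix} torchrun --nproc-per-node 8 --standalone nccl_debug.py"


if __name__ == "__main__":
    for visible in rotations(8):
        print(repro_command(visible, {"NCCL_P2P_DISABLE": "1"}))
```

Passing `{"NCCL_ALGO": "NVLS"}` as `extra_env` gives the Scenario 2 variant of each command.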
NCCL Version
2.26.2
Your platform details
NVIDIA-SMI 550.163.01 Driver Version: 550.163.01 CUDA Version: 12.4
Error Message & Behavior
Here are the results from the repro script:
NCCL 2.26.2 and PyTorch 2.7 Result
W1109 22:17:45.166000 46353 torch/distributed/run.py:766]
W1109 22:17:45.166000 46353 torch/distributed/run.py:766] *****************************************
W1109 22:17:45.166000 46353 torch/distributed/run.py:766] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
W1109 22:17:45.166000 46353 torch/distributed/run.py:766] *****************************************
[rank3]:[W1109 22:17:55.859904220 ProcessGroupNCCL.cpp:4715] [PG ID 0 PG GUID 0 Rank 3] using GPU 3 as device used by this process is currently unknown. This can potentially cause a hang if this rank to GPU mapping is incorrect. You can pecify device_id in init_process_group() to force use of a particular device.
[rank7]:[W1109 22:17:55.369605334 ProcessGroupNCCL.cpp:4715] [PG ID 0 PG GUID 0 Rank 7] using GPU 7 as device used by this process is currently unknown. This can potentially cause a hang if this rank to GPU mapping is incorrect. You can pecify device_id in init_process_group() to force use of a particular device.
[rank6]:[W1109 22:17:56.583847830 ProcessGroupNCCL.cpp:4715] [PG ID 0 PG GUID 0 Rank 6] using GPU 6 as device used by this process is currently unknown. This can potentially cause a hang if this rank to GPU mapping is incorrect. You can pecify device_id in init_process_group() to force use of a particular device.
[rank0]:[W1109 22:17:56.647513632 ProcessGroupNCCL.cpp:4715] [PG ID 0 PG GUID 0 Rank 0] using GPU 0 as device used by this process is currently unknown. This can potentially cause a hang if this rank to GPU mapping is incorrect. You can pecify device_id in init_process_group() to force use of a particular device.
[rank5]:[W1109 22:17:56.647638104 ProcessGroupNCCL.cpp:4715] [PG ID 0 PG GUID 0 Rank 5] using GPU 5 as device used by this process is currently unknown. This can potentially cause a hang if this rank to GPU mapping is incorrect. You can pecify device_id in init_process_group() to force use of a particular device.
[rank1]:[W1109 22:17:56.750041798 ProcessGroupNCCL.cpp:4715] [PG ID 0 PG GUID 0 Rank 1] using GPU 1 as device used by this process is currently unknown. This can potentially cause a hang if this rank to GPU mapping is incorrect. You can pecify device_id in init_process_group() to force use of a particular device.
[rank4]:[W1109 22:17:56.803963466 ProcessGroupNCCL.cpp:4715] [PG ID 0 PG GUID 0 Rank 4] using GPU 4 as device used by this process is currently unknown. This can potentially cause a hang if this rank to GPU mapping is incorrect. You can pecify device_id in init_process_group() to force use of a particular device.
[rank2]:[W1109 22:17:56.825434231 ProcessGroupNCCL.cpp:4715] [PG ID 0 PG GUID 0 Rank 2] using GPU 2 as device used by this process is currently unknown. This can potentially cause a hang if this rank to GPU mapping is incorrect. You can pecify device_id in init_process_group() to force use of a particular device.
================================================================================
[RANK 1/8] NCCL DEBUG SUMMARY
================================================================================
[ENVIRONMENT]
CUDA_VISIBLE_DEVICES: 1,2,3,4,5,6,7,0
NCCL_P2P_DISABLE: 1
NCCL environment variables:
NCCL_P2P_DISABLE: 1
NCCL_VERSION: 2.17.1-1
NV_LIBNCCL_DEV_PACKAGE: libnccl-dev=2.17.1-1+cuda12.1
NV_LIBNCCL_DEV_PACKAGE_NAME: libnccl-dev
NV_LIBNCCL_DEV_PACKAGE_VERSION: 2.17.1-1
NV_LIBNCCL_PACKAGE: libnccl2=2.17.1-1+cuda12.1
NV_LIBNCCL_PACKAGE_NAME: libnccl2
NV_LIBNCCL_PACKAGE_VERSION: 2.17.1-1
TORCH_NCCL_ASYNC_ERROR_HANDLING: 1
[PYTORCH & CUDA]
PyTorch version: 2.7.0+cu126
CUDA available: True
CUDA device count: 8
NCCL version: (2, 26, 2)
[DEVICE ASSIGNMENT]
Local rank: 1
Assigned device: cuda:1
Current device: 1
[NCCL ALL_GATHER RESULTS]
Input tensor: [1.0, 1.0, 1.0, 1.0]
From rank 0: [0.0, 0.0, 0.0, 0.0] (expected 0.0, got 0.0) - ✓
From rank 1: [7.0, 7.0, 7.0, 7.0] (expected 1.0, got 7.0) - ✗
From rank 2: [1.0, 1.0, 1.0, 1.0] (expected 2.0, got 1.0) - ✗
From rank 3: [2.0, 2.0, 2.0, 2.0] (expected 3.0, got 2.0) - ✗
From rank 4: [3.0, 3.0, 3.0, 3.0] (expected 4.0, got 3.0) - ✗
From rank 5: [4.0, 4.0, 4.0, 4.0] (expected 5.0, got 4.0) - ✗
From rank 6: [5.0, 5.0, 5.0, 5.0] (expected 6.0, got 5.0) - ✗
From rank 7: [6.0, 6.0, 6.0, 6.0] (expected 7.0, got 6.0) - ✗
[STATUS]
✗✗✗ FAILURE: Some tests failed on rank 1
================================================================================
================================================================================
[RANK 0/8] NCCL DEBUG SUMMARY
================================================================================
[ENVIRONMENT]
CUDA_VISIBLE_DEVICES: 1,2,3,4,5,6,7,0
NCCL_P2P_DISABLE: 1
NCCL environment variables:
NCCL_P2P_DISABLE: 1
NCCL_VERSION: 2.17.1-1
NV_LIBNCCL_DEV_PACKAGE: libnccl-dev=2.17.1-1+cuda12.1
NV_LIBNCCL_DEV_PACKAGE_NAME: libnccl-dev
NV_LIBNCCL_DEV_PACKAGE_VERSION: 2.17.1-1
NV_LIBNCCL_PACKAGE: libnccl2=2.17.1-1+cuda12.1
NV_LIBNCCL_PACKAGE_NAME: libnccl2
NV_LIBNCCL_PACKAGE_VERSION: 2.17.1-1
TORCH_NCCL_ASYNC_ERROR_HANDLING: 1
[PYTORCH & CUDA]
PyTorch version: 2.7.0+cu126
CUDA available: True
CUDA device count: 8
NCCL version: (2, 26, 2)
[DEVICE ASSIGNMENT]
Local rank: 0
Assigned device: cuda:0
Current device: 0
[NCCL ALL_GATHER RESULTS]
Input tensor: [0.0, 0.0, 0.0, 0.0]
From rank 0: [0.0, 0.0, 0.0, 0.0] (expected 0.0, got 0.0) - ✓
From rank 1: [7.0, 7.0, 7.0, 7.0] (expected 1.0, got 7.0) - ✗
From rank 2: [1.0, 1.0, 1.0, 1.0] (expected 2.0, got 1.0) - ✗
From rank 3: [2.0, 2.0, 2.0, 2.0] (expected 3.0, got 2.0) - ✗
From rank 4: [3.0, 3.0, 3.0, 3.0] (expected 4.0, got 3.0) - ✗
From rank 5: [4.0, 4.0, 4.0, 4.0] (expected 5.0, got 4.0) - ✗
From rank 6: [5.0, 5.0, 5.0, 5.0] (expected 6.0, got 5.0) - ✗
From rank 7: [6.0, 6.0, 6.0, 6.0] (expected 7.0, got 6.0) - ✗
[STATUS]
✗✗✗ FAILURE: Some tests failed on rank 0
================================================================================
================================================================================
[RANK 5/8] NCCL DEBUG SUMMARY
================================================================================
[ENVIRONMENT]
CUDA_VISIBLE_DEVICES: 1,2,3,4,5,6,7,0
NCCL_P2P_DISABLE: 1
NCCL environment variables:
NCCL_P2P_DISABLE: 1
NCCL_VERSION: 2.17.1-1
NV_LIBNCCL_DEV_PACKAGE: libnccl-dev=2.17.1-1+cuda12.1
NV_LIBNCCL_DEV_PACKAGE_NAME: libnccl-dev
NV_LIBNCCL_DEV_PACKAGE_VERSION: 2.17.1-1
NV_LIBNCCL_PACKAGE: libnccl2=2.17.1-1+cuda12.1
NV_LIBNCCL_PACKAGE_NAME: libnccl2
NV_LIBNCCL_PACKAGE_VERSION: 2.17.1-1
TORCH_NCCL_ASYNC_ERROR_HANDLING: 1
[PYTORCH & CUDA]
PyTorch version: 2.7.0+cu126
CUDA available: True
CUDA device count: 8
NCCL version: (2, 26, 2)
[DEVICE ASSIGNMENT]
Local rank: 5
Assigned device: cuda:5
Current device: 5
[NCCL ALL_GATHER RESULTS]
Input tensor: [5.0, 5.0, 5.0, 5.0]
From rank 0: [0.0, 0.0, 0.0, 0.0] (expected 0.0, got 0.0) - ✓
From rank 1: [7.0, 7.0, 7.0, 7.0] (expected 1.0, got 7.0) - ✗
From rank 2: [1.0, 1.0, 1.0, 1.0] (expected 2.0, got 1.0) - ✗
From rank 3: [2.0, 2.0, 2.0, 2.0] (expected 3.0, got 2.0) - ✗
From rank 4: [3.0, 3.0, 3.0, 3.0] (expected 4.0, got 3.0) - ✗
From rank 5: [4.0, 4.0, 4.0, 4.0] (expected 5.0, got 4.0) - ✗
From rank 6: [5.0, 5.0, 5.0, 5.0] (expected 6.0, got 5.0) - ✗
From rank 7: [6.0, 6.0, 6.0, 6.0] (expected 7.0, got 6.0) - ✗
[STATUS]
✗✗✗ FAILURE: Some tests failed on rank 5
================================================================================
================================================================================
[RANK 7/8] NCCL DEBUG SUMMARY
================================================================================
[ENVIRONMENT]
CUDA_VISIBLE_DEVICES: 1,2,3,4,5,6,7,0
NCCL_P2P_DISABLE: 1
NCCL environment variables:
NCCL_P2P_DISABLE: 1
NCCL_VERSION: 2.17.1-1
NV_LIBNCCL_DEV_PACKAGE: libnccl-dev=2.17.1-1+cuda12.1
NV_LIBNCCL_DEV_PACKAGE_NAME: libnccl-dev
NV_LIBNCCL_DEV_PACKAGE_VERSION: 2.17.1-1
NV_LIBNCCL_PACKAGE: libnccl2=2.17.1-1+cuda12.1
NV_LIBNCCL_PACKAGE_NAME: libnccl2
NV_LIBNCCL_PACKAGE_VERSION: 2.17.1-1
TORCH_NCCL_ASYNC_ERROR_HANDLING: 1
[PYTORCH & CUDA]
PyTorch version: 2.7.0+cu126
CUDA available: True
CUDA device count: 8
NCCL version: (2, 26, 2)
[DEVICE ASSIGNMENT]
Local rank: 7
Assigned device: cuda:7
Current device: 7
[NCCL ALL_GATHER RESULTS]
Input tensor: [7.0, 7.0, 7.0, 7.0]
From rank 0: [0.0, 0.0, 0.0, 0.0] (expected 0.0, got 0.0) - ✓
From rank 1: [7.0, 7.0, 7.0, 7.0] (expected 1.0, got 7.0) - ✗
From rank 2: [1.0, 1.0, 1.0, 1.0] (expected 2.0, got 1.0) - ✗
From rank 3: [2.0, 2.0, 2.0, 2.0] (expected 3.0, got 2.0) - ✗
From rank 4: [3.0, 3.0, 3.0, 3.0] (expected 4.0, got 3.0) - ✗
From rank 5: [4.0, 4.0, 4.0, 4.0] (expected 5.0, got 4.0) - ✗
From rank 6: [5.0, 5.0, 5.0, 5.0] (expected 6.0, got 5.0) - ✗
From rank 7: [6.0, 6.0, 6.0, 6.0] (expected 7.0, got 6.0) - ✗
[STATUS]
✗✗✗ FAILURE: Some tests failed on rank 7
================================================================================
================================================================================
[RANK 4/8] NCCL DEBUG SUMMARY
================================================================================
[ENVIRONMENT]
CUDA_VISIBLE_DEVICES: 1,2,3,4,5,6,7,0
NCCL_P2P_DISABLE: 1
NCCL environment variables:
NCCL_P2P_DISABLE: 1
NCCL_VERSION: 2.17.1-1
NV_LIBNCCL_DEV_PACKAGE: libnccl-dev=2.17.1-1+cuda12.1
NV_LIBNCCL_DEV_PACKAGE_NAME: libnccl-dev
NV_LIBNCCL_DEV_PACKAGE_VERSION: 2.17.1-1
NV_LIBNCCL_PACKAGE: libnccl2=2.17.1-1+cuda12.1
NV_LIBNCCL_PACKAGE_NAME: libnccl2
NV_LIBNCCL_PACKAGE_VERSION: 2.17.1-1
TORCH_NCCL_ASYNC_ERROR_HANDLING: 1
[PYTORCH & CUDA]
PyTorch version: 2.7.0+cu126
CUDA available: True
CUDA device count: 8
NCCL version: (2, 26, 2)
[DEVICE ASSIGNMENT]
Local rank: 4
Assigned device: cuda:4
Current device: 4
[NCCL ALL_GATHER RESULTS]
Input tensor: [4.0, 4.0, 4.0, 4.0]
From rank 0: [0.0, 0.0, 0.0, 0.0] (expected 0.0, got 0.0) - ✓
From rank 1: [7.0, 7.0, 7.0, 7.0] (expected 1.0, got 7.0) - ✗
From rank 2: [1.0, 1.0, 1.0, 1.0] (expected 2.0, got 1.0) - ✗
From rank 3: [2.0, 2.0, 2.0, 2.0] (expected 3.0, got 2.0) - ✗
From rank 4: [3.0, 3.0, 3.0, 3.0] (expected 4.0, got 3.0) - ✗
From rank 5: [4.0, 4.0, 4.0, 4.0] (expected 5.0, got 4.0) - ✗
From rank 6: [5.0, 5.0, 5.0, 5.0] (expected 6.0, got 5.0) - ✗
From rank 7: [6.0, 6.0, 6.0, 6.0] (expected 7.0, got 6.0) - ✗
[STATUS]
✗✗✗ FAILURE: Some tests failed on rank 4
================================================================================
================================================================================
[RANK 3/8] NCCL DEBUG SUMMARY
================================================================================
[ENVIRONMENT]
CUDA_VISIBLE_DEVICES: 1,2,3,4,5,6,7,0
NCCL_P2P_DISABLE: 1
NCCL environment variables:
NCCL_P2P_DISABLE: 1
NCCL_VERSION: 2.17.1-1
NV_LIBNCCL_DEV_PACKAGE: libnccl-dev=2.17.1-1+cuda12.1
NV_LIBNCCL_DEV_PACKAGE_NAME: libnccl-dev
NV_LIBNCCL_DEV_PACKAGE_VERSION: 2.17.1-1
NV_LIBNCCL_PACKAGE: libnccl2=2.17.1-1+cuda12.1
NV_LIBNCCL_PACKAGE_NAME: libnccl2
NV_LIBNCCL_PACKAGE_VERSION: 2.17.1-1
TORCH_NCCL_ASYNC_ERROR_HANDLING: 1
[PYTORCH & CUDA]
PyTorch version: 2.7.0+cu126
CUDA available: True
CUDA device count: 8
NCCL version: (2, 26, 2)
[DEVICE ASSIGNMENT]
Local rank: 3
Assigned device: cuda:3
Current device: 3
[NCCL ALL_GATHER RESULTS]
Input tensor: [3.0, 3.0, 3.0, 3.0]
From rank 0: [0.0, 0.0, 0.0, 0.0] (expected 0.0, got 0.0) - ✓
From rank 1: [7.0, 7.0, 7.0, 7.0] (expected 1.0, got 7.0) - ✗
From rank 2: [1.0, 1.0, 1.0, 1.0] (expected 2.0, got 1.0) - ✗
From rank 3: [2.0, 2.0, 2.0, 2.0] (expected 3.0, got 2.0) - ✗
From rank 4: [3.0, 3.0, 3.0, 3.0] (expected 4.0, got 3.0) - ✗
From rank 5: [4.0, 4.0, 4.0, 4.0] (expected 5.0, got 4.0) - ✗
From rank 6: [5.0, 5.0, 5.0, 5.0] (expected 6.0, got 5.0) - ✗
From rank 7: [6.0, 6.0, 6.0, 6.0] (expected 7.0, got 6.0) - ✗
[STATUS]
✗✗✗ FAILURE: Some tests failed on rank 3
================================================================================
================================================================================
[RANK 2/8] NCCL DEBUG SUMMARY
================================================================================
[ENVIRONMENT]
CUDA_VISIBLE_DEVICES: 1,2,3,4,5,6,7,0
NCCL_P2P_DISABLE: 1
NCCL environment variables:
NCCL_P2P_DISABLE: 1
NCCL_VERSION: 2.17.1-1
NV_LIBNCCL_DEV_PACKAGE: libnccl-dev=2.17.1-1+cuda12.1
NV_LIBNCCL_DEV_PACKAGE_NAME: libnccl-dev
NV_LIBNCCL_DEV_PACKAGE_VERSION: 2.17.1-1
NV_LIBNCCL_PACKAGE: libnccl2=2.17.1-1+cuda12.1
NV_LIBNCCL_PACKAGE_NAME: libnccl2
NV_LIBNCCL_PACKAGE_VERSION: 2.17.1-1
TORCH_NCCL_ASYNC_ERROR_HANDLING: 1
[PYTORCH & CUDA]
PyTorch version: 2.7.0+cu126
CUDA available: True
CUDA device count: 8
NCCL version: (2, 26, 2)
[DEVICE ASSIGNMENT]
Local rank: 2
Assigned device: cuda:2
Current device: 2
[NCCL ALL_GATHER RESULTS]
Input tensor: [2.0, 2.0, 2.0, 2.0]
From rank 0: [0.0, 0.0, 0.0, 0.0] (expected 0.0, got 0.0) - ✓
From rank 1: [7.0, 7.0, 7.0, 7.0] (expected 1.0, got 7.0) - ✗
From rank 2: [1.0, 1.0, 1.0, 1.0] (expected 2.0, got 1.0) - ✗
From rank 3: [2.0, 2.0, 2.0, 2.0] (expected 3.0, got 2.0) - ✗
From rank 4: [3.0, 3.0, 3.0, 3.0] (expected 4.0, got 3.0) - ✗
From rank 5: [4.0, 4.0, 4.0, 4.0] (expected 5.0, got 4.0) - ✗
From rank 6: [5.0, 5.0, 5.0, 5.0] (expected 6.0, got 5.0) - ✗
From rank 7: [6.0, 6.0, 6.0, 6.0] (expected 7.0, got 6.0) - ✗
[STATUS]
✗✗✗ FAILURE: Some tests failed on rank 2
================================================================================
================================================================================
[RANK 6/8] NCCL DEBUG SUMMARY
================================================================================
[ENVIRONMENT]
CUDA_VISIBLE_DEVICES: 1,2,3,4,5,6,7,0
NCCL_P2P_DISABLE: 1
NCCL environment variables:
NCCL_P2P_DISABLE: 1
NCCL_VERSION: 2.17.1-1
NV_LIBNCCL_DEV_PACKAGE: libnccl-dev=2.17.1-1+cuda12.1
NV_LIBNCCL_DEV_PACKAGE_NAME: libnccl-dev
NV_LIBNCCL_DEV_PACKAGE_VERSION: 2.17.1-1
NV_LIBNCCL_PACKAGE: libnccl2=2.17.1-1+cuda12.1
NV_LIBNCCL_PACKAGE_NAME: libnccl2
NV_LIBNCCL_PACKAGE_VERSION: 2.17.1-1
TORCH_NCCL_ASYNC_ERROR_HANDLING: 1
[PYTORCH & CUDA]
PyTorch version: 2.7.0+cu126
CUDA available: True
CUDA device count: 8
NCCL version: (2, 26, 2)
[DEVICE ASSIGNMENT]
Local rank: 6
Assigned device: cuda:6
Current device: 6
[NCCL ALL_GATHER RESULTS]
Input tensor: [6.0, 6.0, 6.0, 6.0]
From rank 0: [0.0, 0.0, 0.0, 0.0] (expected 0.0, got 0.0) - ✓
From rank 1: [7.0, 7.0, 7.0, 7.0] (expected 1.0, got 7.0) - ✗
From rank 2: [1.0, 1.0, 1.0, 1.0] (expected 2.0, got 1.0) - ✗
From rank 3: [2.0, 2.0, 2.0, 2.0] (expected 3.0, got 2.0) - ✗
From rank 4: [3.0, 3.0, 3.0, 3.0] (expected 4.0, got 3.0) - ✗
From rank 5: [4.0, 4.0, 4.0, 4.0] (expected 5.0, got 4.0) - ✗
From rank 6: [5.0, 5.0, 5.0, 5.0] (expected 6.0, got 5.0) - ✗
From rank 7: [6.0, 6.0, 6.0, 6.0] (expected 7.0, got 6.0) - ✗
[STATUS]
✗✗✗ FAILURE: Some tests failed on rank 6
================================================================================
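To make the pattern in the failing run easier to see, the result lines can be parsed into a slot-to-source mapping. This is only a log-parsing sketch: the regex and helper names are hypothetical and not part of the repro script, and the sample lines are copied from the rank 1 summary of the NCCL 2.26.2 run above.

```python
import re

# Matches the result lines printed by the repro script.
LINE_RE = re.compile(r"From rank (\d+): .*\(expected ([\d.]+), got ([\d.]+)\)")

SAMPLE = """\
From rank 0: [0.0, 0.0, 0.0, 0.0] (expected 0.0, got 0.0) - ✓
From rank 1: [7.0, 7.0, 7.0, 7.0] (expected 1.0, got 7.0) - ✗
From rank 2: [1.0, 1.0, 1.0, 1.0] (expected 2.0, got 1.0) - ✗
From rank 3: [2.0, 2.0, 2.0, 2.0] (expected 3.0, got 2.0) - ✗
From rank 4: [3.0, 3.0, 3.0, 3.0] (expected 4.0, got 3.0) - ✗
From rank 5: [4.0, 4.0, 4.0, 4.0] (expected 5.0, got 4.0) - ✗
From rank 6: [5.0, 5.0, 5.0, 5.0] (expected 6.0, got 5.0) - ✗
From rank 7: [6.0, 6.0, 6.0, 6.0] (expected 7.0, got 6.0) - ✗
"""


def observed_mapping(text):
    """Return {output slot: rank whose data actually landed in that slot}."""
    mapping = {}
    for line in text.splitlines():
        match = LINE_RE.search(line)
        if match:
            mapping[int(match.group(1))] = int(float(match.group(3)))
    return mapping


def is_permutation(mapping):
    """True if every rank's data appears exactly once, just possibly in the wrong slot."""
    return sorted(mapping.values()) == sorted(mapping.keys())


if __name__ == "__main__":
    mapping = observed_mapping(SAMPLE)
    print(mapping)
    print("pure reordering:", is_permutation(mapping))
    print("misplaced slots:", [slot for slot, src in mapping.items() if slot != src])
```

On this sample the mapping is a permutation of 0..7 with only slot 0 in place, which matches the description of the data being reordered rather than lost or overwritten.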
NCCL 2.21.5 and PyTorch 2.6 Result
W1109 22:19:14.214000 4815 torch/distributed/run.py:792] *****************************************
W1109 22:19:14.214000 4815 torch/distributed/run.py:792] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
W1109 22:19:14.214000 4815 torch/distributed/run.py:792] *****************************************
[rank6]:[W1109 22:19:24.164824990 ProcessGroupNCCL.cpp:4561] [PG ID 0 PG GUID 0 Rank 6] using GPU 6 to perform barrier as devices used by this process are currently unknown. This can potentially cause a hang if this rank to GPU mapping is incorrect. Specify device_ids in barrier() to force use of a particular device, or call init_process_group() with a device_id.
[rank7]:[W1109 22:19:25.700106572 ProcessGroupNCCL.cpp:4561] [PG ID 0 PG GUID 0 Rank 7] using GPU 7 to perform barrier as devices used by this process are currently unknown. This can potentially cause a hang if this rank to GPU mapping is incorrect. Specify device_ids in barrier() to force use of a particular device, or call init_process_group() with a device_id.
[rank1]:[W1109 22:19:25.821279724 ProcessGroupNCCL.cpp:4561] [PG ID 0 PG GUID 0 Rank 1] using GPU 1 to perform barrier as devices used by this process are currently unknown. This can potentially cause a hang if this rank to GPU mapping is incorrect. Specify device_ids in barrier() to force use of a particular device, or call init_process_group() with a device_id.
[rank3]:[W1109 22:19:25.849538390 ProcessGroupNCCL.cpp:4561] [PG ID 0 PG GUID 0 Rank 3] using GPU 3 to perform barrier as devices used by this process are currently unknown. This can potentially cause a hang if this rank to GPU mapping is incorrect. Specify device_ids in barrier() to force use of a particular device, or call init_process_group() with a device_id.
[rank0]:[W1109 22:19:25.944557362 ProcessGroupNCCL.cpp:4561] [PG ID 0 PG GUID 0 Rank 0] using GPU 0 to perform barrier as devices used by this process are currently unknown. This can potentially cause a hang if this rank to GPU mapping is incorrect. Specify device_ids in barrier() to force use of a particular device, or call init_process_group() with a device_id.
[rank5]:[W1109 22:19:25.960418667 ProcessGroupNCCL.cpp:4561] [PG ID 0 PG GUID 0 Rank 5] using GPU 5 to perform barrier as devices used by this process are currently unknown. This can potentially cause a hang if this rank to GPU mapping is incorrect. Specify device_ids in barrier() to force use of a particular device, or call init_process_group() with a device_id.
[rank4]:[W1109 22:19:25.015737744 ProcessGroupNCCL.cpp:4561] [PG ID 0 PG GUID 0 Rank 4] using GPU 4 to perform barrier as devices used by this process are currently unknown. This can potentially cause a hang if this rank to GPU mapping is incorrect. Specify device_ids in barrier() to force use of a particular device, or call init_process_group() with a device_id.
[rank2]:[W1109 22:19:25.026558830 ProcessGroupNCCL.cpp:4561] [PG ID 0 PG GUID 0 Rank 2] using GPU 2 to perform barrier as devices used by this process are currently unknown. This can potentially cause a hang if this rank to GPU mapping is incorrect. Specify device_ids in barrier() to force use of a particular device, or call init_process_group() with a device_id.
================================================================================
[RANK 3/8] NCCL DEBUG SUMMARY
================================================================================
[ENVIRONMENT]
CUDA_VISIBLE_DEVICES: 1,2,3,4,5,6,7,0
NCCL_P2P_DISABLE: 1
NCCL environment variables:
NCCL_P2P_DISABLE: 1
NCCL_VERSION: 2.17.1-1
NV_LIBNCCL_DEV_PACKAGE: libnccl-dev=2.17.1-1+cuda12.1
NV_LIBNCCL_DEV_PACKAGE_NAME: libnccl-dev
NV_LIBNCCL_DEV_PACKAGE_VERSION: 2.17.1-1
NV_LIBNCCL_PACKAGE: libnccl2=2.17.1-1+cuda12.1
NV_LIBNCCL_PACKAGE_NAME: libnccl2
NV_LIBNCCL_PACKAGE_VERSION: 2.17.1-1
TORCH_NCCL_ASYNC_ERROR_HANDLING: 1
[PYTORCH & CUDA]
PyTorch version: 2.6.0+cu124
CUDA available: True
CUDA device count: 8
NCCL version: (2, 21, 5)
[DEVICE ASSIGNMENT]
Local rank: 3
Assigned device: cuda:3
Current device: 3
[NCCL ALL_GATHER RESULTS]
Input tensor: [3.0, 3.0, 3.0, 3.0]
From rank 0: [0.0, 0.0, 0.0, 0.0] (expected 0.0, got 0.0) - ✓
From rank 1: [1.0, 1.0, 1.0, 1.0] (expected 1.0, got 1.0) - ✓
From rank 2: [2.0, 2.0, 2.0, 2.0] (expected 2.0, got 2.0) - ✓
From rank 3: [3.0, 3.0, 3.0, 3.0] (expected 3.0, got 3.0) - ✓
From rank 4: [4.0, 4.0, 4.0, 4.0] (expected 4.0, got 4.0) - ✓
From rank 5: [5.0, 5.0, 5.0, 5.0] (expected 5.0, got 5.0) - ✓
From rank 6: [6.0, 6.0, 6.0, 6.0] (expected 6.0, got 6.0) - ✓
From rank 7: [7.0, 7.0, 7.0, 7.0] (expected 7.0, got 7.0) - ✓
[STATUS]
✓✓✓ SUCCESS: All tests passed on rank 3
================================================================================
================================================================================
[RANK 2/8] NCCL DEBUG SUMMARY
================================================================================
[ENVIRONMENT]
CUDA_VISIBLE_DEVICES: 1,2,3,4,5,6,7,0
NCCL_P2P_DISABLE: 1
NCCL environment variables:
NCCL_P2P_DISABLE: 1
NCCL_VERSION: 2.17.1-1
NV_LIBNCCL_DEV_PACKAGE: libnccl-dev=2.17.1-1+cuda12.1
NV_LIBNCCL_DEV_PACKAGE_NAME: libnccl-dev
NV_LIBNCCL_DEV_PACKAGE_VERSION: 2.17.1-1
NV_LIBNCCL_PACKAGE: libnccl2=2.17.1-1+cuda12.1
NV_LIBNCCL_PACKAGE_NAME: libnccl2
NV_LIBNCCL_PACKAGE_VERSION: 2.17.1-1
TORCH_NCCL_ASYNC_ERROR_HANDLING: 1
[PYTORCH & CUDA]
PyTorch version: 2.6.0+cu124
CUDA available: True
CUDA device count: 8
NCCL version: (2, 21, 5)
[DEVICE ASSIGNMENT]
Local rank: 2
Assigned device: cuda:2
Current device: 2
[NCCL ALL_GATHER RESULTS]
Input tensor: [2.0, 2.0, 2.0, 2.0]
From rank 0: [0.0, 0.0, 0.0, 0.0] (expected 0.0, got 0.0) - ✓
From rank 1: [1.0, 1.0, 1.0, 1.0] (expected 1.0, got 1.0) - ✓
From rank 2: [2.0, 2.0, 2.0, 2.0] (expected 2.0, got 2.0) - ✓
From rank 3: [3.0, 3.0, 3.0, 3.0] (expected 3.0, got 3.0) - ✓
From rank 4: [4.0, 4.0, 4.0, 4.0] (expected 4.0, got 4.0) - ✓
From rank 5: [5.0, 5.0, 5.0, 5.0] (expected 5.0, got 5.0) - ✓
From rank 6: [6.0, 6.0, 6.0, 6.0] (expected 6.0, got 6.0) - ✓
From rank 7: [7.0, 7.0, 7.0, 7.0] (expected 7.0, got 7.0) - ✓
[STATUS]
✓✓✓ SUCCESS: All tests passed on rank 2
================================================================================
================================================================================
[RANK 5/8] NCCL DEBUG SUMMARY
================================================================================
[ENVIRONMENT]
CUDA_VISIBLE_DEVICES: 1,2,3,4,5,6,7,0
NCCL_P2P_DISABLE: 1
NCCL environment variables:
NCCL_P2P_DISABLE: 1
NCCL_VERSION: 2.17.1-1
NV_LIBNCCL_DEV_PACKAGE: libnccl-dev=2.17.1-1+cuda12.1
NV_LIBNCCL_DEV_PACKAGE_NAME: libnccl-dev
NV_LIBNCCL_DEV_PACKAGE_VERSION: 2.17.1-1
NV_LIBNCCL_PACKAGE: libnccl2=2.17.1-1+cuda12.1
NV_LIBNCCL_PACKAGE_NAME: libnccl2
NV_LIBNCCL_PACKAGE_VERSION: 2.17.1-1
TORCH_NCCL_ASYNC_ERROR_HANDLING: 1
[PYTORCH & CUDA]
PyTorch version: 2.6.0+cu124
CUDA available: True
CUDA device count: 8
NCCL version: (2, 21, 5)
[DEVICE ASSIGNMENT]
Local rank: 5
Assigned device: cuda:5
Current device: 5
[NCCL ALL_GATHER RESULTS]
Input tensor: [5.0, 5.0, 5.0, 5.0]
From rank 0: [0.0, 0.0, 0.0, 0.0] (expected 0.0, got 0.0) - ✓
From rank 1: [1.0, 1.0, 1.0, 1.0] (expected 1.0, got 1.0) - ✓
From rank 2: [2.0, 2.0, 2.0, 2.0] (expected 2.0, got 2.0) - ✓
From rank 3: [3.0, 3.0, 3.0, 3.0] (expected 3.0, got 3.0) - ✓
From rank 4: [4.0, 4.0, 4.0, 4.0] (expected 4.0, got 4.0) - ✓
From rank 5: [5.0, 5.0, 5.0, 5.0] (expected 5.0, got 5.0) - ✓
From rank 6: [6.0, 6.0, 6.0, 6.0] (expected 6.0, got 6.0) - ✓
From rank 7: [7.0, 7.0, 7.0, 7.0] (expected 7.0, got 7.0) - ✓
[STATUS]
✓✓✓ SUCCESS: All tests passed on rank 5
================================================================================
================================================================================
[RANK 6/8] NCCL DEBUG SUMMARY
================================================================================
[ENVIRONMENT]
CUDA_VISIBLE_DEVICES: 1,2,3,4,5,6,7,0
NCCL_P2P_DISABLE: 1
NCCL environment variables:
NCCL_P2P_DISABLE: 1
NCCL_VERSION: 2.17.1-1
NV_LIBNCCL_DEV_PACKAGE: libnccl-dev=2.17.1-1+cuda12.1
NV_LIBNCCL_DEV_PACKAGE_NAME: libnccl-dev
NV_LIBNCCL_DEV_PACKAGE_VERSION: 2.17.1-1
NV_LIBNCCL_PACKAGE: libnccl2=2.17.1-1+cuda12.1
NV_LIBNCCL_PACKAGE_NAME: libnccl2
NV_LIBNCCL_PACKAGE_VERSION: 2.17.1-1
TORCH_NCCL_ASYNC_ERROR_HANDLING: 1
[PYTORCH & CUDA]
PyTorch version: 2.6.0+cu124
CUDA available: True
CUDA device count: 8
NCCL version: (2, 21, 5)
[DEVICE ASSIGNMENT]
Local rank: 6
Assigned device: cuda:6
Current device: 6
[NCCL ALL_GATHER RESULTS]
Input tensor: [6.0, 6.0, 6.0, 6.0]
From rank 0: [0.0, 0.0, 0.0, 0.0] (expected 0.0, got 0.0) - ✓
From rank 1: [1.0, 1.0, 1.0, 1.0] (expected 1.0, got 1.0) - ✓
From rank 2: [2.0, 2.0, 2.0, 2.0] (expected 2.0, got 2.0) - ✓
From rank 3: [3.0, 3.0, 3.0, 3.0] (expected 3.0, got 3.0) - ✓
From rank 4: [4.0, 4.0, 4.0, 4.0] (expected 4.0, got 4.0) - ✓
From rank 5: [5.0, 5.0, 5.0, 5.0] (expected 5.0, got 5.0) - ✓
From rank 6: [6.0, 6.0, 6.0, 6.0] (expected 6.0, got 6.0) - ✓
From rank 7: [7.0, 7.0, 7.0, 7.0] (expected 7.0, got 7.0) - ✓
[STATUS]
✓✓✓ SUCCESS: All tests passed on rank 6
================================================================================
================================================================================
[RANK 0/8] NCCL DEBUG SUMMARY
================================================================================
[ENVIRONMENT]
CUDA_VISIBLE_DEVICES: 1,2,3,4,5,6,7,0
NCCL_P2P_DISABLE: 1
NCCL environment variables:
NCCL_P2P_DISABLE: 1
NCCL_VERSION: 2.17.1-1
NV_LIBNCCL_DEV_PACKAGE: libnccl-dev=2.17.1-1+cuda12.1
NV_LIBNCCL_DEV_PACKAGE_NAME: libnccl-dev
NV_LIBNCCL_DEV_PACKAGE_VERSION: 2.17.1-1
NV_LIBNCCL_PACKAGE: libnccl2=2.17.1-1+cuda12.1
NV_LIBNCCL_PACKAGE_NAME: libnccl2
NV_LIBNCCL_PACKAGE_VERSION: 2.17.1-1
TORCH_NCCL_ASYNC_ERROR_HANDLING: 1
[PYTORCH & CUDA]
PyTorch version: 2.6.0+cu124
CUDA available: True
CUDA device count: 8
NCCL version: (2, 21, 5)
[DEVICE ASSIGNMENT]
Local rank: 0
Assigned device: cuda:0
Current device: 0
[NCCL ALL_GATHER RESULTS]
Input tensor: [0.0, 0.0, 0.0, 0.0]
From rank 0: [0.0, 0.0, 0.0, 0.0] (expected 0.0, got 0.0) - ✓
From rank 1: [1.0, 1.0, 1.0, 1.0] (expected 1.0, got 1.0) - ✓
From rank 2: [2.0, 2.0, 2.0, 2.0] (expected 2.0, got 2.0) - ✓
From rank 3: [3.0, 3.0, 3.0, 3.0] (expected 3.0, got 3.0) - ✓
From rank 4: [4.0, 4.0, 4.0, 4.0] (expected 4.0, got 4.0) - ✓
From rank 5: [5.0, 5.0, 5.0, 5.0] (expected 5.0, got 5.0) - ✓
From rank 6: [6.0, 6.0, 6.0, 6.0] (expected 6.0, got 6.0) - ✓
From rank 7: [7.0, 7.0, 7.0, 7.0] (expected 7.0, got 7.0) - ✓
[STATUS]
✓✓✓ SUCCESS: All tests passed on rank 0
================================================================================
================================================================================
[RANK 1/8] NCCL DEBUG SUMMARY
================================================================================
[ENVIRONMENT]
CUDA_VISIBLE_DEVICES: 1,2,3,4,5,6,7,0
NCCL_P2P_DISABLE: 1
NCCL environment variables:
NCCL_P2P_DISABLE: 1
NCCL_VERSION: 2.17.1-1
NV_LIBNCCL_DEV_PACKAGE: libnccl-dev=2.17.1-1+cuda12.1
NV_LIBNCCL_DEV_PACKAGE_NAME: libnccl-dev
NV_LIBNCCL_DEV_PACKAGE_VERSION: 2.17.1-1
NV_LIBNCCL_PACKAGE: libnccl2=2.17.1-1+cuda12.1
NV_LIBNCCL_PACKAGE_NAME: libnccl2
NV_LIBNCCL_PACKAGE_VERSION: 2.17.1-1
TORCH_NCCL_ASYNC_ERROR_HANDLING: 1
[PYTORCH & CUDA]
PyTorch version: 2.6.0+cu124
CUDA available: True
CUDA device count: 8
NCCL version: (2, 21, 5)
[DEVICE ASSIGNMENT]
Local rank: 1
Assigned device: cuda:1
Current device: 1
[NCCL ALL_GATHER RESULTS]
Input tensor: [1.0, 1.0, 1.0, 1.0]
From rank 0: [0.0, 0.0, 0.0, 0.0] (expected 0.0, got 0.0) - ✓
From rank 1: [1.0, 1.0, 1.0, 1.0] (expected 1.0, got 1.0) - ✓
From rank 2: [2.0, 2.0, 2.0, 2.0] (expected 2.0, got 2.0) - ✓
From rank 3: [3.0, 3.0, 3.0, 3.0] (expected 3.0, got 3.0) - ✓
From rank 4: [4.0, 4.0, 4.0, 4.0] (expected 4.0, got 4.0) - ✓
From rank 5: [5.0, 5.0, 5.0, 5.0] (expected 5.0, got 5.0) - ✓
From rank 6: [6.0, 6.0, 6.0, 6.0] (expected 6.0, got 6.0) - ✓
From rank 7: [7.0, 7.0, 7.0, 7.0] (expected 7.0, got 7.0) - ✓
[STATUS]
✓✓✓ SUCCESS: All tests passed on rank 1
================================================================================
================================================================================
[RANK 7/8] NCCL DEBUG SUMMARY
================================================================================
[ENVIRONMENT]
CUDA_VISIBLE_DEVICES: 1,2,3,4,5,6,7,0
NCCL_P2P_DISABLE: 1
NCCL environment variables:
NCCL_P2P_DISABLE: 1
NCCL_VERSION: 2.17.1-1
NV_LIBNCCL_DEV_PACKAGE: libnccl-dev=2.17.1-1+cuda12.1
NV_LIBNCCL_DEV_PACKAGE_NAME: libnccl-dev
NV_LIBNCCL_DEV_PACKAGE_VERSION: 2.17.1-1
NV_LIBNCCL_PACKAGE: libnccl2=2.17.1-1+cuda12.1
NV_LIBNCCL_PACKAGE_NAME: libnccl2
NV_LIBNCCL_PACKAGE_VERSION: 2.17.1-1
TORCH_NCCL_ASYNC_ERROR_HANDLING: 1
[PYTORCH & CUDA]
PyTorch version: 2.6.0+cu124
CUDA available: True
CUDA device count: 8
NCCL version: (2, 21, 5)
[DEVICE ASSIGNMENT]
Local rank: 7
Assigned device: cuda:7
Current device: 7
[NCCL ALL_GATHER RESULTS]
Input tensor: [7.0, 7.0, 7.0, 7.0]
From rank 0: [0.0, 0.0, 0.0, 0.0] (expected 0.0, got 0.0) - ✓
From rank 1: [1.0, 1.0, 1.0, 1.0] (expected 1.0, got 1.0) - ✓
From rank 2: [2.0, 2.0, 2.0, 2.0] (expected 2.0, got 2.0) - ✓
From rank 3: [3.0, 3.0, 3.0, 3.0] (expected 3.0, got 3.0) - ✓
From rank 4: [4.0, 4.0, 4.0, 4.0] (expected 4.0, got 4.0) - ✓
From rank 5: [5.0, 5.0, 5.0, 5.0] (expected 5.0, got 5.0) - ✓
From rank 6: [6.0, 6.0, 6.0, 6.0] (expected 6.0, got 6.0) - ✓
From rank 7: [7.0, 7.0, 7.0, 7.0] (expected 7.0, got 7.0) - ✓
[STATUS]
✓✓✓ SUCCESS: All tests passed on rank 7
================================================================================
================================================================================
[RANK 4/8] NCCL DEBUG SUMMARY
================================================================================
[ENVIRONMENT]
CUDA_VISIBLE_DEVICES: 1,2,3,4,5,6,7,0
NCCL_P2P_DISABLE: 1
NCCL environment variables:
NCCL_P2P_DISABLE: 1
NCCL_VERSION: 2.17.1-1
NV_LIBNCCL_DEV_PACKAGE: libnccl-dev=2.17.1-1+cuda12.1
NV_LIBNCCL_DEV_PACKAGE_NAME: libnccl-dev
NV_LIBNCCL_DEV_PACKAGE_VERSION: 2.17.1-1
NV_LIBNCCL_PACKAGE: libnccl2=2.17.1-1+cuda12.1
NV_LIBNCCL_PACKAGE_NAME: libnccl2
NV_LIBNCCL_PACKAGE_VERSION: 2.17.1-1
TORCH_NCCL_ASYNC_ERROR_HANDLING: 1
[PYTORCH & CUDA]
PyTorch version: 2.6.0+cu124
CUDA available: True
CUDA device count: 8
NCCL version: (2, 21, 5)
[DEVICE ASSIGNMENT]
Local rank: 4
Assigned device: cuda:4
Current device: 4
[NCCL ALL_GATHER RESULTS]
Input tensor: [4.0, 4.0, 4.0, 4.0]
From rank 0: [0.0, 0.0, 0.0, 0.0] (expected 0.0, got 0.0) - ✓
From rank 1: [1.0, 1.0, 1.0, 1.0] (expected 1.0, got 1.0) - ✓
From rank 2: [2.0, 2.0, 2.0, 2.0] (expected 2.0, got 2.0) - ✓
From rank 3: [3.0, 3.0, 3.0, 3.0] (expected 3.0, got 3.0) - ✓
From rank 4: [4.0, 4.0, 4.0, 4.0] (expected 4.0, got 4.0) - ✓
From rank 5: [5.0, 5.0, 5.0, 5.0] (expected 5.0, got 5.0) - ✓
From rank 6: [6.0, 6.0, 6.0, 6.0] (expected 6.0, got 6.0) - ✓
From rank 7: [7.0, 7.0, 7.0, 7.0] (expected 7.0, got 7.0) - ✓
[STATUS]
✓✓✓ SUCCESS: All tests passed on rank 4
================================================================================