
[Issue]: all_gather returns tensors in incorrect rank order for certain permutations of CUDA_VISIBLE_DEVICES and NCCL_ALGO=NVLS (regression in NCCL 2.26.2; NCCL 2.21.5 OK) #1906

@xiangxl-a

Description


How is this issue impacting you?

Data corruption

Share Your Debug Logs

We are seeing silent corruption of the rank order in all_gather in the following two scenarios:

Scenario 1:

  1. Certain permutations of CUDA_VISIBLE_DEVICES (e.g., 1,2,3,4,5,6,7,0 or 3,4,5,6,7,0,1,2)
  2. NCCL_P2P_DISABLE=1
  3. NCCL version is 2.26.2 and PyTorch version is 2.7
  4. Using 8 GPUs on a node.

Scenario 2:

  1. Certain permutations of CUDA_VISIBLE_DEVICES (e.g., 1,2,3,4,5,6,7,0 or 3,4,5,6,7,0,1,2)
  2. NCCL_ALGO=NVLS
  3. NCCL version is 2.26.2 and PyTorch version is 2.7
  4. Using 8 GPUs on a node.

I suspect both scenarios share the same root cause; setting NCCL_P2P_DISABLE=1 simply makes the issue easier to reproduce with a small workload.
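Both example orderings are cyclic rotations of the identity order 0..7 (shift 1 and shift 3). Only these two permutations were reported, so treating rotations as the relevant family is an assumption. If you want to sweep all eight rotations, a hypothetical helper like `rotated_visible_devices` (not part of any library) can generate the CUDA_VISIBLE_DEVICES values:

```python
def rotated_visible_devices(num_gpus: int = 8) -> list[str]:
    """Hypothetical helper: every cyclic rotation of 0..num_gpus-1 as a
    CUDA_VISIBLE_DEVICES string (shift 0 is the identity order)."""
    base = list(range(num_gpus))
    rotations = []
    for shift in range(num_gpus):
        # Rotate left by `shift`; shift=1 gives "1,2,...,num_gpus-1,0".
        order = base[shift:] + base[:shift]
        rotations.append(",".join(str(d) for d in order))
    return rotations


if __name__ == "__main__":
    for cvd in rotated_visible_devices():
        print(cvd)
```

Each printed value can be exported as CUDA_VISIBLE_DEVICES before a separate torchrun launch.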

Steps to Reproduce the Issue

Below is a minimal repro script nccl_debug.py:

# Test PyTorch NCCL
import torch
import torch.distributed as dist
import os

# Initialize distributed to get rank
dist.init_process_group(backend='nccl')
rank = dist.get_rank()
world_size = dist.get_world_size()

# Collect environment information now; it is printed later in a single block
cuda_visible_devices = os.environ.get('CUDA_VISIBLE_DEVICES', 'NOT SET')
nccl_p2p_disable = os.environ.get('NCCL_P2P_DISABLE', 'NOT SET')

# Collect NCCL environment variables
nccl_env_vars = []
for key, value in sorted(os.environ.items()):
    if 'NCCL' in key:
        nccl_env_vars.append(f"{key}: {value}")

# Collect PyTorch and CUDA information
pytorch_version = torch.__version__
cuda_available = torch.cuda.is_available()
device_count = torch.cuda.device_count()
nccl_version = torch.cuda.nccl.version() if cuda_available else "N/A"
device_names = []
if cuda_available:
    for i in range(device_count):
        device_names.append(f"Device {i}: {torch.cuda.get_device_name(i)}")

# Set up device
local_rank = rank % device_count
dev = torch.device("cuda", local_rank)
device_name = torch.cuda.get_device_name(dev)
torch.cuda.set_device(dev)
current_device = torch.cuda.current_device()

# Create tensors and perform all_gather
x = torch.full((4,), float(rank), device=dev)
out = [torch.empty_like(x) for _ in range(world_size)]

dist.barrier()
dist.all_gather(out, x)
dist.barrier()

# Verify results
results = []
all_passed = True
for i, t in enumerate(out):
    expected = float(i)
    actual = t[0].item()
    passed = torch.all(t == float(i))
    results.append(f"  From rank {i}: {t.tolist()} (expected {expected}, got {actual}) - {'✓' if passed else '✗'}")
    if not passed:
        all_passed = False

# Build the entire summary as a single string so it can be printed in one call
summary_lines = []
summary_lines.append(f"\n{'='*80}")
summary_lines.append(f"[RANK {rank}/{world_size}] NCCL DEBUG SUMMARY")
summary_lines.append(f"{'='*80}")
summary_lines.append(f"[ENVIRONMENT]")
summary_lines.append(f"  CUDA_VISIBLE_DEVICES: {cuda_visible_devices}")
summary_lines.append(f"  NCCL_P2P_DISABLE: {nccl_p2p_disable}")
if nccl_env_vars:
    summary_lines.append(f"  NCCL environment variables:")
    for nccl_var in nccl_env_vars:
        summary_lines.append(f"    {nccl_var}")

summary_lines.append(f"[PYTORCH & CUDA]")
summary_lines.append(f"  PyTorch version: {pytorch_version}")
summary_lines.append(f"  CUDA available: {cuda_available}")
summary_lines.append(f"  CUDA device count: {device_count}")
summary_lines.append(f"  NCCL version: {nccl_version}")
for device_name_line in device_names:
    summary_lines.append(f"  {device_name_line}")

summary_lines.append(f"[DEVICE ASSIGNMENT]")
summary_lines.append(f"  Local rank: {local_rank}")
summary_lines.append(f"  Assigned device: {dev}")
summary_lines.append(f"  Device name: {device_name}")
summary_lines.append(f"  Current device: {current_device}")

summary_lines.append(f"[NCCL ALL_GATHER RESULTS]")
summary_lines.append(f"  Input tensor: {x.tolist()}")
for result_line in results:
    summary_lines.append(result_line)

summary_lines.append(f"[STATUS]")
if all_passed:
    summary_lines.append(f"  ✓✓✓ SUCCESS: All tests passed on rank {rank}")
else:
    summary_lines.append(f"  ✗✗✗ FAILURE: Some tests failed on rank {rank}")
summary_lines.append(f"{'='*80}\n")

# Synchronize before printing
dist.barrier()

# Print the whole summary with one print() call to reduce interleaving between ranks
print("\n".join(summary_lines))

# Synchronize after printing
dist.barrier()
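To characterize a failure rather than only flag it, the slot-to-rank mapping can be decomposed into permutation cycles. The helper below is hypothetical (not part of PyTorch or NCCL) and assumes, as in the script above, that each rank fills its tensor with float(rank), so it would be called with `[t[0].item() for t in out]`.

```python
def describe_gather_order(received):
    """Hypothetical helper: summarize which source rank landed in each
    all_gather output slot.

    received[i] is the value found in out[i]. The repro script fills each
    rank's tensor with float(rank), so a correct result is [0.0, 1.0, 2.0, ...].
    """
    world_size = len(received)
    sources = [int(v) for v in received]
    mismatched = [slot for slot, src in enumerate(sources) if src != slot]
    is_permutation = sorted(sources) == list(range(world_size))

    cycles = None
    if is_permutation:
        # Decompose the slot -> source-rank mapping into cycles. Within a
        # cycle c, slot c[k] holds the data of rank c[(k + 1) % len(c)].
        cycles, seen = [], set()
        for start in range(world_size):
            if start in seen:
                continue
            cycle, slot = [], start
            while slot not in seen:
                seen.add(slot)
                cycle.append(slot)
                slot = sources[slot]
            cycles.append(cycle)

    return {
        "mismatched_slots": mismatched,
        "is_permutation": is_permutation,
        "cycles": cycles,
    }
```

For the failing 2.26.2 output, where every rank sees [0.0, 7.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0], this reports cycles [[0], [1, 7, 6, 5, 4, 3, 2]]. Slot 0 is correct, and the contents of slots 1..7 are cyclically shifted by one position. No data is lost; it is only misplaced.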

To reproduce the issue, run:

Scenario 1:

export CUDA_VISIBLE_DEVICES="1,2,3,4,5,6,7,0"
export NCCL_P2P_DISABLE=1
torchrun --nproc-per-node 8 --standalone nccl_debug.py

Scenario 2:

export CUDA_VISIBLE_DEVICES="1,2,3,4,5,6,7,0"
export NCCL_ALGO=NVLS
torchrun --nproc-per-node 8 --standalone nccl_debug.py
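If both command blocks above are pasted into the same shell, Scenario 2 also inherits the exported NCCL_P2P_DISABLE=1 from Scenario 1. A minimal sketch that launches each scenario with an isolated environment follows. `scenario_env`, `torchrun_cmd`, and `run_scenario` are hypothetical helper names, and the sketch assumes `torchrun` is on PATH and nccl_debug.py is in the working directory.

```python
import os
import subprocess

# Environment overrides for the two scenarios, copied from the commands above.
SCENARIOS = {
    1: {"CUDA_VISIBLE_DEVICES": "1,2,3,4,5,6,7,0", "NCCL_P2P_DISABLE": "1"},
    2: {"CUDA_VISIBLE_DEVICES": "1,2,3,4,5,6,7,0", "NCCL_ALGO": "NVLS"},
}


def scenario_env(scenario, base_env=None):
    """Hypothetical helper: build an isolated environment for one scenario."""
    env = dict(os.environ if base_env is None else base_env)
    # Clear knobs set by either scenario so they cannot leak into the other one.
    for key in ("NCCL_P2P_DISABLE", "NCCL_ALGO"):
        env.pop(key, None)
    env.update(SCENARIOS[scenario])
    return env


def torchrun_cmd(script="nccl_debug.py", nproc=8):
    """Hypothetical helper: the launch command above as an argument list."""
    return ["torchrun", "--nproc-per-node", str(nproc), "--standalone", script]


def run_scenario(scenario):
    """Hypothetical helper: launch one scenario (assumes torchrun is on PATH)."""
    return subprocess.run(torchrun_cmd(), env=scenario_env(scenario), check=False)
```

For example, `run_scenario(1)` followed by `run_scenario(2)` runs the two cases back to back without cross-contamination.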

NCCL Version

2.26.2

Your platform details

NVIDIA-SMI 550.163.01 Driver Version: 550.163.01 CUDA Version: 12.4

Error Message & Behavior

Here are the results from the repro script:

Results with NCCL 2.26.2 and PyTorch 2.7 (Scenario 1, NCCL_P2P_DISABLE=1)

W1109 22:17:45.166000 46353 torch/distributed/run.py:766] 
W1109 22:17:45.166000 46353 torch/distributed/run.py:766] *****************************************
W1109 22:17:45.166000 46353 torch/distributed/run.py:766] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. 
W1109 22:17:45.166000 46353 torch/distributed/run.py:766] *****************************************
[rank3]:[W1109 22:17:55.859904220 ProcessGroupNCCL.cpp:4715] [PG ID 0 PG GUID 0 Rank 3]  using GPU 3 as device used by this process is currently unknown. This can potentially cause a hang if this rank to GPU mapping is incorrect. You can pecify device_id in init_process_group() to force use of a particular device.
[rank7]:[W1109 22:17:55.369605334 ProcessGroupNCCL.cpp:4715] [PG ID 0 PG GUID 0 Rank 7]  using GPU 7 as device used by this process is currently unknown. This can potentially cause a hang if this rank to GPU mapping is incorrect. You can pecify device_id in init_process_group() to force use of a particular device.
[rank6]:[W1109 22:17:56.583847830 ProcessGroupNCCL.cpp:4715] [PG ID 0 PG GUID 0 Rank 6]  using GPU 6 as device used by this process is currently unknown. This can potentially cause a hang if this rank to GPU mapping is incorrect. You can pecify device_id in init_process_group() to force use of a particular device.
[rank0]:[W1109 22:17:56.647513632 ProcessGroupNCCL.cpp:4715] [PG ID 0 PG GUID 0 Rank 0]  using GPU 0 as device used by this process is currently unknown. This can potentially cause a hang if this rank to GPU mapping is incorrect. You can pecify device_id in init_process_group() to force use of a particular device.
[rank5]:[W1109 22:17:56.647638104 ProcessGroupNCCL.cpp:4715] [PG ID 0 PG GUID 0 Rank 5]  using GPU 5 as device used by this process is currently unknown. This can potentially cause a hang if this rank to GPU mapping is incorrect. You can pecify device_id in init_process_group() to force use of a particular device.
[rank1]:[W1109 22:17:56.750041798 ProcessGroupNCCL.cpp:4715] [PG ID 0 PG GUID 0 Rank 1]  using GPU 1 as device used by this process is currently unknown. This can potentially cause a hang if this rank to GPU mapping is incorrect. You can pecify device_id in init_process_group() to force use of a particular device.
[rank4]:[W1109 22:17:56.803963466 ProcessGroupNCCL.cpp:4715] [PG ID 0 PG GUID 0 Rank 4]  using GPU 4 as device used by this process is currently unknown. This can potentially cause a hang if this rank to GPU mapping is incorrect. You can pecify device_id in init_process_group() to force use of a particular device.
[rank2]:[W1109 22:17:56.825434231 ProcessGroupNCCL.cpp:4715] [PG ID 0 PG GUID 0 Rank 2]  using GPU 2 as device used by this process is currently unknown. This can potentially cause a hang if this rank to GPU mapping is incorrect. You can pecify device_id in init_process_group() to force use of a particular device.

================================================================================
[RANK 1/8] NCCL DEBUG SUMMARY
================================================================================
[ENVIRONMENT]
  CUDA_VISIBLE_DEVICES: 1,2,3,4,5,6,7,0
  NCCL_P2P_DISABLE: 1
  NCCL environment variables:
    NCCL_P2P_DISABLE: 1
    NCCL_VERSION: 2.17.1-1
    NV_LIBNCCL_DEV_PACKAGE: libnccl-dev=2.17.1-1+cuda12.1
    NV_LIBNCCL_DEV_PACKAGE_NAME: libnccl-dev
    NV_LIBNCCL_DEV_PACKAGE_VERSION: 2.17.1-1
    NV_LIBNCCL_PACKAGE: libnccl2=2.17.1-1+cuda12.1
    NV_LIBNCCL_PACKAGE_NAME: libnccl2
    NV_LIBNCCL_PACKAGE_VERSION: 2.17.1-1
    TORCH_NCCL_ASYNC_ERROR_HANDLING: 1
[PYTORCH & CUDA]
  PyTorch version: 2.7.0+cu126
  CUDA available: True
  CUDA device count: 8
  NCCL version: (2, 26, 2)
[DEVICE ASSIGNMENT]
  Local rank: 1
  Assigned device: cuda:1
  Current device: 1
[NCCL ALL_GATHER RESULTS]
  Input tensor: [1.0, 1.0, 1.0, 1.0]
  From rank 0: [0.0, 0.0, 0.0, 0.0] (expected 0.0, got 0.0) - ✓
  From rank 1: [7.0, 7.0, 7.0, 7.0] (expected 1.0, got 7.0) - ✗
  From rank 2: [1.0, 1.0, 1.0, 1.0] (expected 2.0, got 1.0) - ✗
  From rank 3: [2.0, 2.0, 2.0, 2.0] (expected 3.0, got 2.0) - ✗
  From rank 4: [3.0, 3.0, 3.0, 3.0] (expected 4.0, got 3.0) - ✗
  From rank 5: [4.0, 4.0, 4.0, 4.0] (expected 5.0, got 4.0) - ✗
  From rank 6: [5.0, 5.0, 5.0, 5.0] (expected 6.0, got 5.0) - ✗
  From rank 7: [6.0, 6.0, 6.0, 6.0] (expected 7.0, got 6.0) - ✗
[STATUS]
  ✗✗✗ FAILURE: Some tests failed on rank 1
================================================================================

================================================================================
[RANK 0/8] NCCL DEBUG SUMMARY
================================================================================
[ENVIRONMENT]
  CUDA_VISIBLE_DEVICES: 1,2,3,4,5,6,7,0
  NCCL_P2P_DISABLE: 1
  NCCL environment variables:
    NCCL_P2P_DISABLE: 1
    NCCL_VERSION: 2.17.1-1
    NV_LIBNCCL_DEV_PACKAGE: libnccl-dev=2.17.1-1+cuda12.1
    NV_LIBNCCL_DEV_PACKAGE_NAME: libnccl-dev
    NV_LIBNCCL_DEV_PACKAGE_VERSION: 2.17.1-1
    NV_LIBNCCL_PACKAGE: libnccl2=2.17.1-1+cuda12.1
    NV_LIBNCCL_PACKAGE_NAME: libnccl2
    NV_LIBNCCL_PACKAGE_VERSION: 2.17.1-1
    TORCH_NCCL_ASYNC_ERROR_HANDLING: 1
[PYTORCH & CUDA]
  PyTorch version: 2.7.0+cu126
  CUDA available: True
  CUDA device count: 8
  NCCL version: (2, 26, 2)
[DEVICE ASSIGNMENT]
  Local rank: 0
  Assigned device: cuda:0
  Current device: 0
[NCCL ALL_GATHER RESULTS]
  Input tensor: [0.0, 0.0, 0.0, 0.0]
  From rank 0: [0.0, 0.0, 0.0, 0.0] (expected 0.0, got 0.0) - ✓
  From rank 1: [7.0, 7.0, 7.0, 7.0] (expected 1.0, got 7.0) - ✗
  From rank 2: [1.0, 1.0, 1.0, 1.0] (expected 2.0, got 1.0) - ✗
  From rank 3: [2.0, 2.0, 2.0, 2.0] (expected 3.0, got 2.0) - ✗
  From rank 4: [3.0, 3.0, 3.0, 3.0] (expected 4.0, got 3.0) - ✗
  From rank 5: [4.0, 4.0, 4.0, 4.0] (expected 5.0, got 4.0) - ✗
  From rank 6: [5.0, 5.0, 5.0, 5.0] (expected 6.0, got 5.0) - ✗
  From rank 7: [6.0, 6.0, 6.0, 6.0] (expected 7.0, got 6.0) - ✗
[STATUS]
  ✗✗✗ FAILURE: Some tests failed on rank 0
================================================================================

================================================================================
[RANK 5/8] NCCL DEBUG SUMMARY
================================================================================
[ENVIRONMENT]
  CUDA_VISIBLE_DEVICES: 1,2,3,4,5,6,7,0
  NCCL_P2P_DISABLE: 1
  NCCL environment variables:
    NCCL_P2P_DISABLE: 1
    NCCL_VERSION: 2.17.1-1
    NV_LIBNCCL_DEV_PACKAGE: libnccl-dev=2.17.1-1+cuda12.1
    NV_LIBNCCL_DEV_PACKAGE_NAME: libnccl-dev
    NV_LIBNCCL_DEV_PACKAGE_VERSION: 2.17.1-1
    NV_LIBNCCL_PACKAGE: libnccl2=2.17.1-1+cuda12.1
    NV_LIBNCCL_PACKAGE_NAME: libnccl2
    NV_LIBNCCL_PACKAGE_VERSION: 2.17.1-1
    TORCH_NCCL_ASYNC_ERROR_HANDLING: 1
[PYTORCH & CUDA]
  PyTorch version: 2.7.0+cu126
  CUDA available: True
  CUDA device count: 8
  NCCL version: (2, 26, 2)
[DEVICE ASSIGNMENT]
  Local rank: 5
  Assigned device: cuda:5
  Current device: 5
[NCCL ALL_GATHER RESULTS]
  Input tensor: [5.0, 5.0, 5.0, 5.0]
  From rank 0: [0.0, 0.0, 0.0, 0.0] (expected 0.0, got 0.0) - ✓
  From rank 1: [7.0, 7.0, 7.0, 7.0] (expected 1.0, got 7.0) - ✗
  From rank 2: [1.0, 1.0, 1.0, 1.0] (expected 2.0, got 1.0) - ✗
  From rank 3: [2.0, 2.0, 2.0, 2.0] (expected 3.0, got 2.0) - ✗
  From rank 4: [3.0, 3.0, 3.0, 3.0] (expected 4.0, got 3.0) - ✗
  From rank 5: [4.0, 4.0, 4.0, 4.0] (expected 5.0, got 4.0) - ✗
  From rank 6: [5.0, 5.0, 5.0, 5.0] (expected 6.0, got 5.0) - ✗
  From rank 7: [6.0, 6.0, 6.0, 6.0] (expected 7.0, got 6.0) - ✗
[STATUS]
  ✗✗✗ FAILURE: Some tests failed on rank 5
================================================================================

================================================================================
[RANK 7/8] NCCL DEBUG SUMMARY
================================================================================
[ENVIRONMENT]
  CUDA_VISIBLE_DEVICES: 1,2,3,4,5,6,7,0
  NCCL_P2P_DISABLE: 1
  NCCL environment variables:
    NCCL_P2P_DISABLE: 1
    NCCL_VERSION: 2.17.1-1
    NV_LIBNCCL_DEV_PACKAGE: libnccl-dev=2.17.1-1+cuda12.1
    NV_LIBNCCL_DEV_PACKAGE_NAME: libnccl-dev
    NV_LIBNCCL_DEV_PACKAGE_VERSION: 2.17.1-1
    NV_LIBNCCL_PACKAGE: libnccl2=2.17.1-1+cuda12.1
    NV_LIBNCCL_PACKAGE_NAME: libnccl2
    NV_LIBNCCL_PACKAGE_VERSION: 2.17.1-1
    TORCH_NCCL_ASYNC_ERROR_HANDLING: 1
[PYTORCH & CUDA]
  PyTorch version: 2.7.0+cu126
  CUDA available: True
  CUDA device count: 8
  NCCL version: (2, 26, 2)
[DEVICE ASSIGNMENT]
  Local rank: 7
  Assigned device: cuda:7
  Current device: 7
[NCCL ALL_GATHER RESULTS]
  Input tensor: [7.0, 7.0, 7.0, 7.0]
  From rank 0: [0.0, 0.0, 0.0, 0.0] (expected 0.0, got 0.0) - ✓
  From rank 1: [7.0, 7.0, 7.0, 7.0] (expected 1.0, got 7.0) - ✗
  From rank 2: [1.0, 1.0, 1.0, 1.0] (expected 2.0, got 1.0) - ✗
  From rank 3: [2.0, 2.0, 2.0, 2.0] (expected 3.0, got 2.0) - ✗
  From rank 4: [3.0, 3.0, 3.0, 3.0] (expected 4.0, got 3.0) - ✗
  From rank 5: [4.0, 4.0, 4.0, 4.0] (expected 5.0, got 4.0) - ✗
  From rank 6: [5.0, 5.0, 5.0, 5.0] (expected 6.0, got 5.0) - ✗
  From rank 7: [6.0, 6.0, 6.0, 6.0] (expected 7.0, got 6.0) - ✗
[STATUS]
  ✗✗✗ FAILURE: Some tests failed on rank 7
================================================================================

================================================================================
[RANK 4/8] NCCL DEBUG SUMMARY
================================================================================
[ENVIRONMENT]
  CUDA_VISIBLE_DEVICES: 1,2,3,4,5,6,7,0
  NCCL_P2P_DISABLE: 1
  NCCL environment variables:
    NCCL_P2P_DISABLE: 1
    NCCL_VERSION: 2.17.1-1
    NV_LIBNCCL_DEV_PACKAGE: libnccl-dev=2.17.1-1+cuda12.1
    NV_LIBNCCL_DEV_PACKAGE_NAME: libnccl-dev
    NV_LIBNCCL_DEV_PACKAGE_VERSION: 2.17.1-1
    NV_LIBNCCL_PACKAGE: libnccl2=2.17.1-1+cuda12.1
    NV_LIBNCCL_PACKAGE_NAME: libnccl2
    NV_LIBNCCL_PACKAGE_VERSION: 2.17.1-1
    TORCH_NCCL_ASYNC_ERROR_HANDLING: 1
[PYTORCH & CUDA]
  PyTorch version: 2.7.0+cu126
  CUDA available: True
  CUDA device count: 8
  NCCL version: (2, 26, 2)
[DEVICE ASSIGNMENT]
  Local rank: 4
  Assigned device: cuda:4
  Current device: 4
[NCCL ALL_GATHER RESULTS]
  Input tensor: [4.0, 4.0, 4.0, 4.0]
  From rank 0: [0.0, 0.0, 0.0, 0.0] (expected 0.0, got 0.0) - ✓
  From rank 1: [7.0, 7.0, 7.0, 7.0] (expected 1.0, got 7.0) - ✗
  From rank 2: [1.0, 1.0, 1.0, 1.0] (expected 2.0, got 1.0) - ✗
  From rank 3: [2.0, 2.0, 2.0, 2.0] (expected 3.0, got 2.0) - ✗
  From rank 4: [3.0, 3.0, 3.0, 3.0] (expected 4.0, got 3.0) - ✗
  From rank 5: [4.0, 4.0, 4.0, 4.0] (expected 5.0, got 4.0) - ✗
  From rank 6: [5.0, 5.0, 5.0, 5.0] (expected 6.0, got 5.0) - ✗
  From rank 7: [6.0, 6.0, 6.0, 6.0] (expected 7.0, got 6.0) - ✗
[STATUS]
  ✗✗✗ FAILURE: Some tests failed on rank 4
================================================================================

================================================================================
[RANK 3/8] NCCL DEBUG SUMMARY
================================================================================
[ENVIRONMENT]
  CUDA_VISIBLE_DEVICES: 1,2,3,4,5,6,7,0
  NCCL_P2P_DISABLE: 1
  NCCL environment variables:
    NCCL_P2P_DISABLE: 1
    NCCL_VERSION: 2.17.1-1
    NV_LIBNCCL_DEV_PACKAGE: libnccl-dev=2.17.1-1+cuda12.1
    NV_LIBNCCL_DEV_PACKAGE_NAME: libnccl-dev
    NV_LIBNCCL_DEV_PACKAGE_VERSION: 2.17.1-1
    NV_LIBNCCL_PACKAGE: libnccl2=2.17.1-1+cuda12.1
    NV_LIBNCCL_PACKAGE_NAME: libnccl2
    NV_LIBNCCL_PACKAGE_VERSION: 2.17.1-1
    TORCH_NCCL_ASYNC_ERROR_HANDLING: 1
[PYTORCH & CUDA]
  PyTorch version: 2.7.0+cu126
  CUDA available: True
  CUDA device count: 8
  NCCL version: (2, 26, 2)
[DEVICE ASSIGNMENT]
  Local rank: 3
  Assigned device: cuda:3
  Current device: 3
[NCCL ALL_GATHER RESULTS]
  Input tensor: [3.0, 3.0, 3.0, 3.0]
  From rank 0: [0.0, 0.0, 0.0, 0.0] (expected 0.0, got 0.0) - ✓
  From rank 1: [7.0, 7.0, 7.0, 7.0] (expected 1.0, got 7.0) - ✗
  From rank 2: [1.0, 1.0, 1.0, 1.0] (expected 2.0, got 1.0) - ✗
  From rank 3: [2.0, 2.0, 2.0, 2.0] (expected 3.0, got 2.0) - ✗
  From rank 4: [3.0, 3.0, 3.0, 3.0] (expected 4.0, got 3.0) - ✗
  From rank 5: [4.0, 4.0, 4.0, 4.0] (expected 5.0, got 4.0) - ✗
  From rank 6: [5.0, 5.0, 5.0, 5.0] (expected 6.0, got 5.0) - ✗
  From rank 7: [6.0, 6.0, 6.0, 6.0] (expected 7.0, got 6.0) - ✗
[STATUS]
  ✗✗✗ FAILURE: Some tests failed on rank 3
================================================================================

================================================================================
[RANK 2/8] NCCL DEBUG SUMMARY
================================================================================
[ENVIRONMENT]
  CUDA_VISIBLE_DEVICES: 1,2,3,4,5,6,7,0
  NCCL_P2P_DISABLE: 1
  NCCL environment variables:
    NCCL_P2P_DISABLE: 1
    NCCL_VERSION: 2.17.1-1
    NV_LIBNCCL_DEV_PACKAGE: libnccl-dev=2.17.1-1+cuda12.1
    NV_LIBNCCL_DEV_PACKAGE_NAME: libnccl-dev
    NV_LIBNCCL_DEV_PACKAGE_VERSION: 2.17.1-1
    NV_LIBNCCL_PACKAGE: libnccl2=2.17.1-1+cuda12.1
    NV_LIBNCCL_PACKAGE_NAME: libnccl2
    NV_LIBNCCL_PACKAGE_VERSION: 2.17.1-1
    TORCH_NCCL_ASYNC_ERROR_HANDLING: 1
[PYTORCH & CUDA]
  PyTorch version: 2.7.0+cu126
  CUDA available: True
  CUDA device count: 8
  NCCL version: (2, 26, 2)
[DEVICE ASSIGNMENT]
  Local rank: 2
  Assigned device: cuda:2
  Current device: 2
[NCCL ALL_GATHER RESULTS]
  Input tensor: [2.0, 2.0, 2.0, 2.0]
  From rank 0: [0.0, 0.0, 0.0, 0.0] (expected 0.0, got 0.0) - ✓
  From rank 1: [7.0, 7.0, 7.0, 7.0] (expected 1.0, got 7.0) - ✗
  From rank 2: [1.0, 1.0, 1.0, 1.0] (expected 2.0, got 1.0) - ✗
  From rank 3: [2.0, 2.0, 2.0, 2.0] (expected 3.0, got 2.0) - ✗
  From rank 4: [3.0, 3.0, 3.0, 3.0] (expected 4.0, got 3.0) - ✗
  From rank 5: [4.0, 4.0, 4.0, 4.0] (expected 5.0, got 4.0) - ✗
  From rank 6: [5.0, 5.0, 5.0, 5.0] (expected 6.0, got 5.0) - ✗
  From rank 7: [6.0, 6.0, 6.0, 6.0] (expected 7.0, got 6.0) - ✗
[STATUS]
  ✗✗✗ FAILURE: Some tests failed on rank 2
================================================================================

================================================================================
[RANK 6/8] NCCL DEBUG SUMMARY
================================================================================
[ENVIRONMENT]
  CUDA_VISIBLE_DEVICES: 1,2,3,4,5,6,7,0
  NCCL_P2P_DISABLE: 1
  NCCL environment variables:
    NCCL_P2P_DISABLE: 1
    NCCL_VERSION: 2.17.1-1
    NV_LIBNCCL_DEV_PACKAGE: libnccl-dev=2.17.1-1+cuda12.1
    NV_LIBNCCL_DEV_PACKAGE_NAME: libnccl-dev
    NV_LIBNCCL_DEV_PACKAGE_VERSION: 2.17.1-1
    NV_LIBNCCL_PACKAGE: libnccl2=2.17.1-1+cuda12.1
    NV_LIBNCCL_PACKAGE_NAME: libnccl2
    NV_LIBNCCL_PACKAGE_VERSION: 2.17.1-1
    TORCH_NCCL_ASYNC_ERROR_HANDLING: 1
[PYTORCH & CUDA]
  PyTorch version: 2.7.0+cu126
  CUDA available: True
  CUDA device count: 8
  NCCL version: (2, 26, 2)
[DEVICE ASSIGNMENT]
  Local rank: 6
  Assigned device: cuda:6
  Current device: 6
[NCCL ALL_GATHER RESULTS]
  Input tensor: [6.0, 6.0, 6.0, 6.0]
  From rank 0: [0.0, 0.0, 0.0, 0.0] (expected 0.0, got 0.0) - ✓
  From rank 1: [7.0, 7.0, 7.0, 7.0] (expected 1.0, got 7.0) - ✗
  From rank 2: [1.0, 1.0, 1.0, 1.0] (expected 2.0, got 1.0) - ✗
  From rank 3: [2.0, 2.0, 2.0, 2.0] (expected 3.0, got 2.0) - ✗
  From rank 4: [3.0, 3.0, 3.0, 3.0] (expected 4.0, got 3.0) - ✗
  From rank 5: [4.0, 4.0, 4.0, 4.0] (expected 5.0, got 4.0) - ✗
  From rank 6: [5.0, 5.0, 5.0, 5.0] (expected 6.0, got 5.0) - ✗
  From rank 7: [6.0, 6.0, 6.0, 6.0] (expected 7.0, got 6.0) - ✗
[STATUS]
  ✗✗✗ FAILURE: Some tests failed on rank 6
================================================================================
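Every rank above reports the same wrong order. To check this across many runs, you can parse the captured torchrun output. The sketch below is hypothetical. It assumes each rank's summary block reaches stdout contiguously: the script prints it in one call, but output from different processes is not guaranteed to stay un-interleaved.

```python
import re
from collections import defaultdict

# Hypothetical patterns for the repro script's output, e.g.
#   "[RANK 1/8] NCCL DEBUG SUMMARY"
#   "  From rank 2: [1.0, 1.0, 1.0, 1.0] (expected 2.0, got 1.0) - ✗"
RANK_HEADER = re.compile(r"^\[RANK (\d+)/\d+\] NCCL DEBUG SUMMARY")
FROM_LINE = re.compile(r"^\s*From rank (\d+): \[([^\]]*)\]")


def observed_orders(log_text):
    """Map each reporting rank to the first element it received in every slot."""
    per_rank = defaultdict(dict)
    current_rank = None
    for line in log_text.splitlines():
        header = RANK_HEADER.match(line)
        if header:
            current_rank = int(header.group(1))
            continue
        match = FROM_LINE.match(line)
        if match and current_rank is not None:
            slot = int(match.group(1))
            first_value = float(match.group(2).split(",")[0])
            per_rank[current_rank][slot] = first_value
    # Order each rank's values by slot index.
    return {rank: [slots[s] for s in sorted(slots)] for rank, slots in per_rank.items()}


def ranks_agree(orders):
    """True if every rank reported the same slot order (or nothing was parsed)."""
    return len({tuple(values) for values in orders.values()}) <= 1
```

The lowercase `[rankN]:[W...]` warning lines from ProcessGroupNCCL do not match either pattern, so they are skipped.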

Results with NCCL 2.21.5 and PyTorch 2.6 (same Scenario 1 settings)

W1109 22:19:14.214000 4815 torch/distributed/run.py:792] *****************************************
W1109 22:19:14.214000 4815 torch/distributed/run.py:792] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. 
W1109 22:19:14.214000 4815 torch/distributed/run.py:792] *****************************************
[rank6]:[W1109 22:19:24.164824990 ProcessGroupNCCL.cpp:4561] [PG ID 0 PG GUID 0 Rank 6]  using GPU 6 to perform barrier as devices used by this process are currently unknown. This can potentially cause a hang if this rank to GPU mapping is incorrect. Specify device_ids in barrier() to force use of a particular device, or call init_process_group() with a device_id.
[rank7]:[W1109 22:19:25.700106572 ProcessGroupNCCL.cpp:4561] [PG ID 0 PG GUID 0 Rank 7]  using GPU 7 to perform barrier as devices used by this process are currently unknown. This can potentially cause a hang if this rank to GPU mapping is incorrect. Specify device_ids in barrier() to force use of a particular device, or call init_process_group() with a device_id.
[rank1]:[W1109 22:19:25.821279724 ProcessGroupNCCL.cpp:4561] [PG ID 0 PG GUID 0 Rank 1]  using GPU 1 to perform barrier as devices used by this process are currently unknown. This can potentially cause a hang if this rank to GPU mapping is incorrect. Specify device_ids in barrier() to force use of a particular device, or call init_process_group() with a device_id.
[rank3]:[W1109 22:19:25.849538390 ProcessGroupNCCL.cpp:4561] [PG ID 0 PG GUID 0 Rank 3]  using GPU 3 to perform barrier as devices used by this process are currently unknown. This can potentially cause a hang if this rank to GPU mapping is incorrect. Specify device_ids in barrier() to force use of a particular device, or call init_process_group() with a device_id.
[rank0]:[W1109 22:19:25.944557362 ProcessGroupNCCL.cpp:4561] [PG ID 0 PG GUID 0 Rank 0]  using GPU 0 to perform barrier as devices used by this process are currently unknown. This can potentially cause a hang if this rank to GPU mapping is incorrect. Specify device_ids in barrier() to force use of a particular device, or call init_process_group() with a device_id.
[rank5]:[W1109 22:19:25.960418667 ProcessGroupNCCL.cpp:4561] [PG ID 0 PG GUID 0 Rank 5]  using GPU 5 to perform barrier as devices used by this process are currently unknown. This can potentially cause a hang if this rank to GPU mapping is incorrect. Specify device_ids in barrier() to force use of a particular device, or call init_process_group() with a device_id.
[rank4]:[W1109 22:19:25.015737744 ProcessGroupNCCL.cpp:4561] [PG ID 0 PG GUID 0 Rank 4]  using GPU 4 to perform barrier as devices used by this process are currently unknown. This can potentially cause a hang if this rank to GPU mapping is incorrect. Specify device_ids in barrier() to force use of a particular device, or call init_process_group() with a device_id.
[rank2]:[W1109 22:19:25.026558830 ProcessGroupNCCL.cpp:4561] [PG ID 0 PG GUID 0 Rank 2]  using GPU 2 to perform barrier as devices used by this process are currently unknown. This can potentially cause a hang if this rank to GPU mapping is incorrect. Specify device_ids in barrier() to force use of a particular device, or call init_process_group() with a device_id.

================================================================================
[RANK 3/8] NCCL DEBUG SUMMARY
================================================================================
[ENVIRONMENT]
  CUDA_VISIBLE_DEVICES: 1,2,3,4,5,6,7,0
  NCCL_P2P_DISABLE: 1
  NCCL environment variables:
    NCCL_P2P_DISABLE: 1
    NCCL_VERSION: 2.17.1-1
    NV_LIBNCCL_DEV_PACKAGE: libnccl-dev=2.17.1-1+cuda12.1
    NV_LIBNCCL_DEV_PACKAGE_NAME: libnccl-dev
    NV_LIBNCCL_DEV_PACKAGE_VERSION: 2.17.1-1
    NV_LIBNCCL_PACKAGE: libnccl2=2.17.1-1+cuda12.1
    NV_LIBNCCL_PACKAGE_NAME: libnccl2
    NV_LIBNCCL_PACKAGE_VERSION: 2.17.1-1
    TORCH_NCCL_ASYNC_ERROR_HANDLING: 1
[PYTORCH & CUDA]
  PyTorch version: 2.6.0+cu124
  CUDA available: True
  CUDA device count: 8
  NCCL version: (2, 21, 5)
[DEVICE ASSIGNMENT]
  Local rank: 3
  Assigned device: cuda:3
  Current device: 3
[NCCL ALL_GATHER RESULTS]
  Input tensor: [3.0, 3.0, 3.0, 3.0]
  From rank 0: [0.0, 0.0, 0.0, 0.0] (expected 0.0, got 0.0) - ✓
  From rank 1: [1.0, 1.0, 1.0, 1.0] (expected 1.0, got 1.0) - ✓
  From rank 2: [2.0, 2.0, 2.0, 2.0] (expected 2.0, got 2.0) - ✓
  From rank 3: [3.0, 3.0, 3.0, 3.0] (expected 3.0, got 3.0) - ✓
  From rank 4: [4.0, 4.0, 4.0, 4.0] (expected 4.0, got 4.0) - ✓
  From rank 5: [5.0, 5.0, 5.0, 5.0] (expected 5.0, got 5.0) - ✓
  From rank 6: [6.0, 6.0, 6.0, 6.0] (expected 6.0, got 6.0) - ✓
  From rank 7: [7.0, 7.0, 7.0, 7.0] (expected 7.0, got 7.0) - ✓
[STATUS]
  ✓✓✓ SUCCESS: All tests passed on rank 3
================================================================================

================================================================================
[RANK 2/8] NCCL DEBUG SUMMARY
================================================================================
[ENVIRONMENT]
  CUDA_VISIBLE_DEVICES: 1,2,3,4,5,6,7,0
  NCCL_P2P_DISABLE: 1
  NCCL environment variables:
    NCCL_P2P_DISABLE: 1
    NCCL_VERSION: 2.17.1-1
    NV_LIBNCCL_DEV_PACKAGE: libnccl-dev=2.17.1-1+cuda12.1
    NV_LIBNCCL_DEV_PACKAGE_NAME: libnccl-dev
    NV_LIBNCCL_DEV_PACKAGE_VERSION: 2.17.1-1
    NV_LIBNCCL_PACKAGE: libnccl2=2.17.1-1+cuda12.1
    NV_LIBNCCL_PACKAGE_NAME: libnccl2
    NV_LIBNCCL_PACKAGE_VERSION: 2.17.1-1
    TORCH_NCCL_ASYNC_ERROR_HANDLING: 1
[PYTORCH & CUDA]
  PyTorch version: 2.6.0+cu124
  CUDA available: True
  CUDA device count: 8
  NCCL version: (2, 21, 5)
[DEVICE ASSIGNMENT]
  Local rank: 2
  Assigned device: cuda:2
  Current device: 2
[NCCL ALL_GATHER RESULTS]
  Input tensor: [2.0, 2.0, 2.0, 2.0]
  From rank 0: [0.0, 0.0, 0.0, 0.0] (expected 0.0, got 0.0) - ✓
  From rank 1: [1.0, 1.0, 1.0, 1.0] (expected 1.0, got 1.0) - ✓
  From rank 2: [2.0, 2.0, 2.0, 2.0] (expected 2.0, got 2.0) - ✓
  From rank 3: [3.0, 3.0, 3.0, 3.0] (expected 3.0, got 3.0) - ✓
  From rank 4: [4.0, 4.0, 4.0, 4.0] (expected 4.0, got 4.0) - ✓
  From rank 5: [5.0, 5.0, 5.0, 5.0] (expected 5.0, got 5.0) - ✓
  From rank 6: [6.0, 6.0, 6.0, 6.0] (expected 6.0, got 6.0) - ✓
  From rank 7: [7.0, 7.0, 7.0, 7.0] (expected 7.0, got 7.0) - ✓
[STATUS]
  ✓✓✓ SUCCESS: All tests passed on rank 2
================================================================================

================================================================================
[RANK 5/8] NCCL DEBUG SUMMARY
================================================================================
[ENVIRONMENT]
  CUDA_VISIBLE_DEVICES: 1,2,3,4,5,6,7,0
  NCCL_P2P_DISABLE: 1
  NCCL environment variables:
    NCCL_P2P_DISABLE: 1
    NCCL_VERSION: 2.17.1-1
    NV_LIBNCCL_DEV_PACKAGE: libnccl-dev=2.17.1-1+cuda12.1
    NV_LIBNCCL_DEV_PACKAGE_NAME: libnccl-dev
    NV_LIBNCCL_DEV_PACKAGE_VERSION: 2.17.1-1
    NV_LIBNCCL_PACKAGE: libnccl2=2.17.1-1+cuda12.1
    NV_LIBNCCL_PACKAGE_NAME: libnccl2
    NV_LIBNCCL_PACKAGE_VERSION: 2.17.1-1
    TORCH_NCCL_ASYNC_ERROR_HANDLING: 1
[PYTORCH & CUDA]
  PyTorch version: 2.6.0+cu124
  CUDA available: True
  CUDA device count: 8
  NCCL version: (2, 21, 5)
[DEVICE ASSIGNMENT]
  Local rank: 5
  Assigned device: cuda:5
  Current device: 5
[NCCL ALL_GATHER RESULTS]
  Input tensor: [5.0, 5.0, 5.0, 5.0]
  From rank 0: [0.0, 0.0, 0.0, 0.0] (expected 0.0, got 0.0) - ✓
  From rank 1: [1.0, 1.0, 1.0, 1.0] (expected 1.0, got 1.0) - ✓
  From rank 2: [2.0, 2.0, 2.0, 2.0] (expected 2.0, got 2.0) - ✓
  From rank 3: [3.0, 3.0, 3.0, 3.0] (expected 3.0, got 3.0) - ✓
  From rank 4: [4.0, 4.0, 4.0, 4.0] (expected 4.0, got 4.0) - ✓
  From rank 5: [5.0, 5.0, 5.0, 5.0] (expected 5.0, got 5.0) - ✓
  From rank 6: [6.0, 6.0, 6.0, 6.0] (expected 6.0, got 6.0) - ✓
  From rank 7: [7.0, 7.0, 7.0, 7.0] (expected 7.0, got 7.0) - ✓
[STATUS]
  ✓✓✓ SUCCESS: All tests passed on rank 5
================================================================================

================================================================================
[RANK 6/8] NCCL DEBUG SUMMARY
================================================================================
[ENVIRONMENT]
  CUDA_VISIBLE_DEVICES: 1,2,3,4,5,6,7,0
  NCCL_P2P_DISABLE: 1
  NCCL environment variables:
    NCCL_P2P_DISABLE: 1
    NCCL_VERSION: 2.17.1-1
    NV_LIBNCCL_DEV_PACKAGE: libnccl-dev=2.17.1-1+cuda12.1
    NV_LIBNCCL_DEV_PACKAGE_NAME: libnccl-dev
    NV_LIBNCCL_DEV_PACKAGE_VERSION: 2.17.1-1
    NV_LIBNCCL_PACKAGE: libnccl2=2.17.1-1+cuda12.1
    NV_LIBNCCL_PACKAGE_NAME: libnccl2
    NV_LIBNCCL_PACKAGE_VERSION: 2.17.1-1
    TORCH_NCCL_ASYNC_ERROR_HANDLING: 1
[PYTORCH & CUDA]
  PyTorch version: 2.6.0+cu124
  CUDA available: True
  CUDA device count: 8
  NCCL version: (2, 21, 5)
[DEVICE ASSIGNMENT]
  Local rank: 6
  Assigned device: cuda:6
  Current device: 6
[NCCL ALL_GATHER RESULTS]
  Input tensor: [6.0, 6.0, 6.0, 6.0]
  From rank 0: [0.0, 0.0, 0.0, 0.0] (expected 0.0, got 0.0) - ✓
  From rank 1: [1.0, 1.0, 1.0, 1.0] (expected 1.0, got 1.0) - ✓
  From rank 2: [2.0, 2.0, 2.0, 2.0] (expected 2.0, got 2.0) - ✓
  From rank 3: [3.0, 3.0, 3.0, 3.0] (expected 3.0, got 3.0) - ✓
  From rank 4: [4.0, 4.0, 4.0, 4.0] (expected 4.0, got 4.0) - ✓
  From rank 5: [5.0, 5.0, 5.0, 5.0] (expected 5.0, got 5.0) - ✓
  From rank 6: [6.0, 6.0, 6.0, 6.0] (expected 6.0, got 6.0) - ✓
  From rank 7: [7.0, 7.0, 7.0, 7.0] (expected 7.0, got 7.0) - ✓
[STATUS]
  ✓✓✓ SUCCESS: All tests passed on rank 6
================================================================================

================================================================================
[RANK 0/8] NCCL DEBUG SUMMARY
================================================================================
[ENVIRONMENT]
  CUDA_VISIBLE_DEVICES: 1,2,3,4,5,6,7,0
  NCCL_P2P_DISABLE: 1
  NCCL environment variables:
    NCCL_P2P_DISABLE: 1
    NCCL_VERSION: 2.17.1-1
    NV_LIBNCCL_DEV_PACKAGE: libnccl-dev=2.17.1-1+cuda12.1
    NV_LIBNCCL_DEV_PACKAGE_NAME: libnccl-dev
    NV_LIBNCCL_DEV_PACKAGE_VERSION: 2.17.1-1
    NV_LIBNCCL_PACKAGE: libnccl2=2.17.1-1+cuda12.1
    NV_LIBNCCL_PACKAGE_NAME: libnccl2
    NV_LIBNCCL_PACKAGE_VERSION: 2.17.1-1
    TORCH_NCCL_ASYNC_ERROR_HANDLING: 1
[PYTORCH & CUDA]
  PyTorch version: 2.6.0+cu124
  CUDA available: True
  CUDA device count: 8
  NCCL version: (2, 21, 5)
[DEVICE ASSIGNMENT]
  Local rank: 0
  Assigned device: cuda:0
  Current device: 0
[NCCL ALL_GATHER RESULTS]
  Input tensor: [0.0, 0.0, 0.0, 0.0]
  From rank 0: [0.0, 0.0, 0.0, 0.0] (expected 0.0, got 0.0) - ✓
  From rank 1: [1.0, 1.0, 1.0, 1.0] (expected 1.0, got 1.0) - ✓
  From rank 2: [2.0, 2.0, 2.0, 2.0] (expected 2.0, got 2.0) - ✓
  From rank 3: [3.0, 3.0, 3.0, 3.0] (expected 3.0, got 3.0) - ✓
  From rank 4: [4.0, 4.0, 4.0, 4.0] (expected 4.0, got 4.0) - ✓
  From rank 5: [5.0, 5.0, 5.0, 5.0] (expected 5.0, got 5.0) - ✓
  From rank 6: [6.0, 6.0, 6.0, 6.0] (expected 6.0, got 6.0) - ✓
  From rank 7: [7.0, 7.0, 7.0, 7.0] (expected 7.0, got 7.0) - ✓
[STATUS]
  ✓✓✓ SUCCESS: All tests passed on rank 0
================================================================================

================================================================================
[RANK 1/8] NCCL DEBUG SUMMARY
================================================================================
[ENVIRONMENT]
  CUDA_VISIBLE_DEVICES: 1,2,3,4,5,6,7,0
  NCCL_P2P_DISABLE: 1
  NCCL environment variables:
    NCCL_P2P_DISABLE: 1
    NCCL_VERSION: 2.17.1-1
    NV_LIBNCCL_DEV_PACKAGE: libnccl-dev=2.17.1-1+cuda12.1
    NV_LIBNCCL_DEV_PACKAGE_NAME: libnccl-dev
    NV_LIBNCCL_DEV_PACKAGE_VERSION: 2.17.1-1
    NV_LIBNCCL_PACKAGE: libnccl2=2.17.1-1+cuda12.1
    NV_LIBNCCL_PACKAGE_NAME: libnccl2
    NV_LIBNCCL_PACKAGE_VERSION: 2.17.1-1
    TORCH_NCCL_ASYNC_ERROR_HANDLING: 1
[PYTORCH & CUDA]
  PyTorch version: 2.6.0+cu124
  CUDA available: True
  CUDA device count: 8
  NCCL version: (2, 21, 5)
[DEVICE ASSIGNMENT]
  Local rank: 1
  Assigned device: cuda:1
  Current device: 1
[NCCL ALL_GATHER RESULTS]
  Input tensor: [1.0, 1.0, 1.0, 1.0]
  From rank 0: [0.0, 0.0, 0.0, 0.0] (expected 0.0, got 0.0) - ✓
  From rank 1: [1.0, 1.0, 1.0, 1.0] (expected 1.0, got 1.0) - ✓
  From rank 2: [2.0, 2.0, 2.0, 2.0] (expected 2.0, got 2.0) - ✓
  From rank 3: [3.0, 3.0, 3.0, 3.0] (expected 3.0, got 3.0) - ✓
  From rank 4: [4.0, 4.0, 4.0, 4.0] (expected 4.0, got 4.0) - ✓
  From rank 5: [5.0, 5.0, 5.0, 5.0] (expected 5.0, got 5.0) - ✓
  From rank 6: [6.0, 6.0, 6.0, 6.0] (expected 6.0, got 6.0) - ✓
  From rank 7: [7.0, 7.0, 7.0, 7.0] (expected 7.0, got 7.0) - ✓
[STATUS]
  ✓✓✓ SUCCESS: All tests passed on rank 1
================================================================================

================================================================================
[RANK 7/8] NCCL DEBUG SUMMARY
================================================================================
[ENVIRONMENT]
  CUDA_VISIBLE_DEVICES: 1,2,3,4,5,6,7,0
  NCCL_P2P_DISABLE: 1
  NCCL environment variables:
    NCCL_P2P_DISABLE: 1
    NCCL_VERSION: 2.17.1-1
    NV_LIBNCCL_DEV_PACKAGE: libnccl-dev=2.17.1-1+cuda12.1
    NV_LIBNCCL_DEV_PACKAGE_NAME: libnccl-dev
    NV_LIBNCCL_DEV_PACKAGE_VERSION: 2.17.1-1
    NV_LIBNCCL_PACKAGE: libnccl2=2.17.1-1+cuda12.1
    NV_LIBNCCL_PACKAGE_NAME: libnccl2
    NV_LIBNCCL_PACKAGE_VERSION: 2.17.1-1
    TORCH_NCCL_ASYNC_ERROR_HANDLING: 1
[PYTORCH & CUDA]
  PyTorch version: 2.6.0+cu124
  CUDA available: True
  CUDA device count: 8
  NCCL version: (2, 21, 5)
[DEVICE ASSIGNMENT]
  Local rank: 7
  Assigned device: cuda:7
  Current device: 7
[NCCL ALL_GATHER RESULTS]
  Input tensor: [7.0, 7.0, 7.0, 7.0]
  From rank 0: [0.0, 0.0, 0.0, 0.0] (expected 0.0, got 0.0) - ✓
  From rank 1: [1.0, 1.0, 1.0, 1.0] (expected 1.0, got 1.0) - ✓
  From rank 2: [2.0, 2.0, 2.0, 2.0] (expected 2.0, got 2.0) - ✓
  From rank 3: [3.0, 3.0, 3.0, 3.0] (expected 3.0, got 3.0) - ✓
  From rank 4: [4.0, 4.0, 4.0, 4.0] (expected 4.0, got 4.0) - ✓
  From rank 5: [5.0, 5.0, 5.0, 5.0] (expected 5.0, got 5.0) - ✓
  From rank 6: [6.0, 6.0, 6.0, 6.0] (expected 6.0, got 6.0) - ✓
  From rank 7: [7.0, 7.0, 7.0, 7.0] (expected 7.0, got 7.0) - ✓
[STATUS]
  ✓✓✓ SUCCESS: All tests passed on rank 7
================================================================================

================================================================================
[RANK 4/8] NCCL DEBUG SUMMARY
================================================================================
[ENVIRONMENT]
  CUDA_VISIBLE_DEVICES: 1,2,3,4,5,6,7,0
  NCCL_P2P_DISABLE: 1
  NCCL environment variables:
    NCCL_P2P_DISABLE: 1
    NCCL_VERSION: 2.17.1-1
    NV_LIBNCCL_DEV_PACKAGE: libnccl-dev=2.17.1-1+cuda12.1
    NV_LIBNCCL_DEV_PACKAGE_NAME: libnccl-dev
    NV_LIBNCCL_DEV_PACKAGE_VERSION: 2.17.1-1
    NV_LIBNCCL_PACKAGE: libnccl2=2.17.1-1+cuda12.1
    NV_LIBNCCL_PACKAGE_NAME: libnccl2
    NV_LIBNCCL_PACKAGE_VERSION: 2.17.1-1
    TORCH_NCCL_ASYNC_ERROR_HANDLING: 1
[PYTORCH & CUDA]
  PyTorch version: 2.6.0+cu124
  CUDA available: True
  CUDA device count: 8
  NCCL version: (2, 21, 5)
[DEVICE ASSIGNMENT]
  Local rank: 4
  Assigned device: cuda:4
  Current device: 4
[NCCL ALL_GATHER RESULTS]
  Input tensor: [4.0, 4.0, 4.0, 4.0]
  From rank 0: [0.0, 0.0, 0.0, 0.0] (expected 0.0, got 0.0) - ✓
  From rank 1: [1.0, 1.0, 1.0, 1.0] (expected 1.0, got 1.0) - ✓
  From rank 2: [2.0, 2.0, 2.0, 2.0] (expected 2.0, got 2.0) - ✓
  From rank 3: [3.0, 3.0, 3.0, 3.0] (expected 3.0, got 3.0) - ✓
  From rank 4: [4.0, 4.0, 4.0, 4.0] (expected 4.0, got 4.0) - ✓
  From rank 5: [5.0, 5.0, 5.0, 5.0] (expected 5.0, got 5.0) - ✓
  From rank 6: [6.0, 6.0, 6.0, 6.0] (expected 6.0, got 6.0) - ✓
  From rank 7: [7.0, 7.0, 7.0, 7.0] (expected 7.0, got 7.0) - ✓
[STATUS]
  ✓✓✓ SUCCESS: All tests passed on rank 4
================================================================================
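With 8 ranks each printing a long summary, it is easy to miss a single bad slot by eye. Below is a small helper for scanning the output. It is hypothetical and not part of the original nccl_debug.py. It assumes the result lines keep exactly the `From rank N: [...] (expected X, got Y)` format shown above. It flags any slot where the reported value, or any element of the gathered slice, differs from the expected source rank.

```python
import re

# Assumes result lines shaped like the ones in the summaries above, e.g.
#   "From rank 3: [3.0, 3.0, 3.0, 3.0] (expected 3.0, got 3.0) - ✓"
RESULT_LINE = re.compile(
    r"From rank (\d+): \[([^\]]*)\] \(expected ([\d.]+), got ([\d.]+)\)"
)


def find_rank_order_mismatches(log_text):
    """Return (source_rank, expected, got) for every all_gather slot that is wrong."""
    mismatches = []
    for match in RESULT_LINE.finditer(log_text):
        source_rank = int(match.group(1))
        expected = float(match.group(3))
        got = float(match.group(4))
        # Check every element of the gathered slice, not only the summary value.
        values = [float(v) for v in match.group(2).split(",") if v.strip()]
        if got != expected or any(v != expected for v in values):
            mismatches.append((source_rank, expected, got))
    return mismatches


# The second line is a hypothetical failing slot, for illustration only.
sample = (
    "  From rank 0: [0.0, 0.0, 0.0, 0.0] (expected 0.0, got 0.0) - ✓\n"
    "  From rank 1: [2.0, 2.0, 2.0, 2.0] (expected 1.0, got 2.0) - ✗\n"
)
print(find_rank_order_mismatches(sample))
```

Running it over the concatenated output of all ranks returns an empty list when every slot is in rank order, as in the NCCL 2.21.5 runs above.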