
BUG: system.init_data_model_parallel() Prevents Nsight Systems (nsys) from Tracing GPU Hardware in Distributed Mode #315

@ziyuhuang123

Description


Environment Details

Tutel Version: v0.4 (installed from the main branch)

PyTorch Version: 2.5

CUDA Driver Version: 550.54.15 (Supports CUDA 12.4)

Nsight Systems CLI Version: 2023.4.4

GPU Configuration: 2x A100 80GB

Observed Behavior

When running the distributed benchmark (tutel/examples/helloworld.py) launched via torch.distributed.run and profiled with nsys, the resulting report fails to capture the GPU hardware (HW) timelines.

The report successfully traces CPU activity and high-level NCCL API calls.

The timelines for GPU 0 / GPU 1, Kernels (GEMM computation), and Memory [C2C] (All-to-All communication) are missing.

Expected Behavior

The nsys report should display the Kernels and Memory [C2C] hardware timelines running concurrently, so that communication-computation overlap can be analyzed.

Steps to Reproduce (Minimal Example)

The issue is reproducible by launching the original script with the required profiling flags:

  1. Launch Command (Failing Trace)

     nsys profile --trace-fork-before-exec=true -o tutel_fail.nsys \
         /path/to/python3 -m torch.distributed.run --nproc-per-node=2 ~/tutel_examples/helloworld.py
  2. Workaround / Proof of Concept (Successful Trace)
    The issue is resolved by replacing Tutel's custom initialization with standard PyTorch initialization.

Change helloworld.py (Approx. Lines 53-56):

The block containing parallel_env = system.init_data_model_parallel(...) should be replaced with the standard PyTorch initialization shown below.

# CODE BLOCK TO REPLACE TUTEL'S INIT:
import os
import torch
import torch.distributed as dist

# ... (inside the script, in place of parallel_env = system.init_data_model_parallel(...))
dist.init_process_group(backend='nccl' if args.device == 'cuda' else 'gloo')
local_rank = int(os.environ['LOCAL_RANK'])
device = torch.device(f'cuda:{local_rank}')
torch.cuda.set_device(device)

# Re-create the helper variables that Tutel's parallel_env normally provides:
dist_rank = dist.get_rank()
dist_world_size = dist.get_world_size()
dist_print = print
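With this replacement in place, the same nsys launch produces a report that contains the missing hardware rows. The command is identical to the failing one; only the output file name below is illustrative:

nsys profile --trace-fork-before-exec=true -o tutel_ok.nsys \
    /path/to/python3 -m torch.distributed.run --nproc-per-node=2 ~/tutel_examples/helloworld.py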

Root Cause Analysis

The problem is a conflict between Tutel's custom initialization and the CUDA Profiling Tools Interface (CUPTI) during multi-process injection.

tutel.system.init_data_model_parallel() performs its CUDA/NCCL setup very early in the child process's lifetime.

When nsys attempts to attach to the hardware via CUPTI, Tutel's non-standard, early initialization interferes with the timing window CUPTI needs to inject its hardware tracing hooks into the child processes.

This results in a silent failure: nsys believes it has attached, but low-level hardware tracing is effectively disabled for the worker processes.
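One way to probe this hypothesis (a minimal sketch, assuming the script is launched under torch.distributed.run as above) is to check whether the child process already holds a CUDA context around the init call:

# Hedged diagnostic sketch: check how early a CUDA context is created.
# torch.cuda.is_initialized() reports whether this process already holds a
# CUDA context, i.e. whether any CUDA runtime call has been made yet.
import torch
from tutel import system

print('CUDA initialized before Tutel init:', torch.cuda.is_initialized())
parallel_env = system.init_data_model_parallel(backend='nccl')
print('CUDA initialized after Tutel init:', torch.cuda.is_initialized())

If the first print already reports True, a CUDA context is being created at import time, well before nsys/CUPTI can finish attaching to the child process.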

Suggested Fix

It is recommended to review the implementation of init_data_model_parallel() in tutel/system.py to ensure that custom CUDA operations or stream/group creation do not run early enough to interfere with CUPTI's initialization window.
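For reference, a minimal sketch of the ordering the workaround above relies on; the helper name is hypothetical and this is not Tutel's actual implementation, only an illustration of keeping all CUDA work after process-group creation:

# Hypothetical reordering sketch (not tutel/system.py's actual code): perform
# CPU-side setup and process-group creation first, and make the first CUDA
# call only afterwards, so CUPTI can attach to the torch.distributed.run
# children before any CUDA context exists.
import os
import torch
import torch.distributed as dist

def init_distributed_for_profiling(backend='nccl'):
    # Guard: no CUDA context should exist yet in this child process.
    assert not torch.cuda.is_initialized(), \
        'CUDA context created before init; nsys/CUPTI tracing may break.'
    dist.init_process_group(backend=backend)    # CPU-side only at this point
    local_rank = int(os.environ['LOCAL_RANK'])  # set by torch.distributed.run
    torch.cuda.set_device(local_rank)           # first CUDA call happens here
    return dist.get_rank(), dist.get_world_size(), local_rank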
