@ahmd-k ahmd-k commented Nov 19, 2025

Context

PyTorch Distributed's ProcessGroupNCCL has a watchdog thread that detects collective hangs or errors. For Meta's NCCLX and RCCLX, the watchdog triggers a custom NCCL API we have defined, ncclCommDump(), which prints data from our ProxyTrace. This matters because we can't rely on the ProxyTrace print in commFree(): PyTorch Distributed does not always clean up all the communicators after NCCL errors.

We want to open-source this functionality so that PyTorch Distributed (PTD) jobs on open-source RCCL get the same behavior.
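
For illustration, here is a minimal caller-side sketch, assuming the ncclCommDump() signature shown in the review thread below; dumpOnError() and its error handling are hypothetical, not part of this PR.

#include <cstdio>
#include <string>
#include <unordered_map>

#include "nccl.h"  // this PR adds the ncclCommDump() declaration here

// Hypothetical hook: dump trace state when a hang or error is detected.
void dumpOnError(ncclComm_t comm) {
  std::unordered_map<std::string, std::string> dump;
  // Logs the ProxyTrace contents and fills `dump` with structured trace data.
  if (ncclCommDump(comm, dump) == ncclSuccess) {
    for (const auto& kv : dump) {
      std::fprintf(stderr, "%s: %s\n", kv.first.c_str(), kv.second.c_str());
    }
  }
}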

Implementation Details

  • In a new file, add a barebones ncclCommDump() API that dumps the ProxyTrace contents, similar to our current behavior in commFree()
    • Add ncclCommDump() to nccl.h
    • Modify logs so we can distinguish between ncclCommDump and commFree output
  • Introduce timestamps for post and lastSend, and add them to the ProxyTrace dump
    • Add a map from proxy op stage to timestamp
    • Add a helper function for setting the timestamps
    • Add calls to set the timestamps in transport/net.cc
  • Make the public APIs for ProxyTrace thread-safe (a rough sketch follows this list)
    • Instead of keeping the APIs as non-member functions that called ProxyTrace member functions underneath, move them into the class. That extra layer of indirection was confusing and made it difficult to synchronize the APIs with a shared mutex.
      • proxyTraceInit() -> ProxyTrace::ProxyTrace()
      • updateProxyOpCounter()
      • setProxyOpTimestamp()
      • addNewProxyOp()
    • Acquire a lock at the start of each public API call.
    • Make internal implementation details like addNewProxyTraceOpImpl() private. They are not called by other RCCL code directly and should not be public.
    • Update the unit tests
      • Remove resetAll() and its corresponding unit test; it is not used anywhere.
      • Remove the direct tests of getOrCreateProxyOpId(); it is exercised through the public API that calls it underneath (addNewProxyOp()).
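
A rough sketch of the resulting shape, covering the stage-to-timestamp map and the locking described above. The enum values, member layout, and lock granularity here are illustrative assumptions, not the PR's exact code:

#include <chrono>
#include <cstdint>
#include <mutex>
#include <unordered_map>

// Illustrative stand-in for the proxy op stages that receive timestamps.
enum class ProxyOpStage { kPost, kLastSend };

class ProxyTrace {
 public:
  ProxyTrace() = default;  // replaces the free function proxyTraceInit()

  // Public APIs (also updateProxyOpCounter(), addNewProxyOp(), ...) take the
  // shared mutex before touching any trace state.
  void setProxyOpTimestamp(uint64_t opId, ProxyOpStage stage) {
    std::lock_guard<std::mutex> lock(mutex_);
    timestamps_[opId][stage] = std::chrono::high_resolution_clock::now();
  }

 private:
  // Internal helpers such as addNewProxyTraceOpImpl() stay private.
  std::mutex mutex_;
  std::unordered_map<
      uint64_t,
      std::unordered_map<ProxyOpStage,
                         std::chrono::high_resolution_clock::time_point>>
      timestamps_;
};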

Testing

Testing these changes requires a multi-node setup. We did this internally with a 2-node PyTorch Distributed job as well as a 2-node MPI test. We drive the MPI test with a somewhat convoluted Python script, but the final run arguments look like this:

" ".join(
    [
        f"{MPIRUN}",
        f"-x LD_LIBRARY_PATH={OMPI_LIB}:{RCCL_BUILD_DIR_LIB}",
        f"-np {total_gpus}",
        f"-host {hosts_with_slots}",
        "--allow-run-as-root",
        "-x RCCL_ENABLE_PROXY_TRACE=1",
        "-x NCCL_DEBUG=WARN",
        "-x HSA_NO_SCRATCH_RECLAIM=1",
        f"{rccl_test_binary_path} -b {args.minbytes} -e {args.maxbytes} -g 1 -f 2",
    ]
),
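
Here -x RCCL_ENABLE_PROXY_TRACE=1 is what enables ProxyTrace collection for the run, and NCCL_DEBUG=WARN keeps the dump, which is emitted at WARN level, visible in the logs.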

This is what the logs look like at the end of the run when the communicator is dumped. The postT and sendT fields carry the new post and lastSend timestamps introduced in this PR.

[0] ~/rccl/build/release/hipify/src/init.cc:445 NCCL WARN commDump for all active ops mapSizeMB:0.01
createT:1763627168308, lastT:1763627168328, postT:0, sendT:0, cntNm:7, <1502607184520987562:1367:0>, [fu,pr,pa,tb,ck]:4,2,1,0,2097152, 0->8(R), chan:53, status:DONE, ns:240, nb:1048576, po:240, ke:0, tail/h:39240, recvT:39240, connSz/h:39238, trans:240, flushed:240, recvd:240, done:240
createT:1763627168308, lastT:1763627168329, postT:1763627168328, sendT:1763627168328, cntNm:7, <1502607184520987562:1367:10>, [fu,pr,pa,tb,ck]:4,2,1,0,2097152, 0->8(S), chan:21, status:DONE, ns:240, nb:1048576, po:240, ke:240, tail/h:39238, recvT:39240, connSz/h:1048576, trans:240, flushed:0, recvd:0, done:240
createT:1763627168308, lastT:1763627168329, postT:1763627168328, sendT:1763627168329, cntNm:7, <1502607184520987562:1367:11>, [fu,pr,pa,tb,ck]:4,2,1,0,2097152, 0->8(S), chan:42, status:DONE, ns:240, nb:1048576, po:240, ke:240, tail/h:39238, recvT:39240

The unit tests also pass:

$ sudo ONLY_FUNCS="AllReduce * * Sum f32" ./install.sh --local_gpu_only --disable-msccl-kernel --prefix /tmp/rccl 
$ UT_DATATYPES=ncclfloat32 UT_REDOPS=prod /tmp/rccl/bin/rccl-UnitTests --gtest_filter="ProxyTraceTestFixture.*"
Note: Google Test filter = ProxyTraceTestFixture.*
[==========] Running 5 tests from 1 test suite.
[----------] Global test environment set-up.
[----------] 5 tests from ProxyTraceTestFixture
[ RUN      ] ProxyTraceTestFixture.nonEmptySingleton
[       OK ] ProxyTraceTestFixture.nonEmptySingleton (0 ms)
[ RUN      ] ProxyTraceTestFixture.addTraceOp
[       OK ] ProxyTraceTestFixture.addTraceOp (0 ms)
[ RUN      ] ProxyTraceTestFixture.getMapSizeMB
[       OK ] ProxyTraceTestFixture.getMapSizeMB (0 ms)
[ RUN      ] ProxyTraceTestFixture.updateTraceOp
[       OK ] ProxyTraceTestFixture.updateTraceOp (0 ms)
[ RUN      ] ProxyTraceTestFixture.updateTraceOp2
[       OK ] ProxyTraceTestFixture.updateTraceOp2 (0 ms)
[----------] 5 tests from ProxyTraceTestFixture (0 ms total)

[----------] Global test environment tear-down
[==========] 5 tests from 1 test suite ran. (0 ms total)
[  PASSED  ] 5 tests.

References

ProcessGroupNCCL dump call: https://github.com/pytorch/pytorch/blob/fcc78410a8e51107a7f4a15431e57da137741aee/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp#L400-L412

return;
}

traceOpPtr->timestamps[counter] = std::chrono::high_resolution_clock::now();

To be pedantic with timing, you may want to sandwich timing calls between atomic signal fences or employ the DoNotOptimize builtin. The compiler does not guarantee that timing calls aren't reordered relative to other blocks of code.
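
For example, a fenced wrapper along these lines (a sketch of the suggestion, not this PR's code) prevents the compiler from moving surrounding work across the clock read:

#include <atomic>
#include <chrono>

// Hypothetical helper: fence the clock read so the compiler cannot reorder
// the timed code relative to the timing call.
static inline std::chrono::high_resolution_clock::time_point fencedNow() {
  std::atomic_signal_fence(std::memory_order_seq_cst);
  auto t = std::chrono::high_resolution_clock::now();
  std::atomic_signal_fence(std::memory_order_seq_cst);
  return t;
}

// At the call site above: traceOpPtr->timestamps[counter] = fencedNow();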

__attribute__((visibility("default")))
ncclResult_t ncclCommDump(
    const ncclComm_t comm,
    std::unordered_map<std::string, std::string>& map) {

Make const std::unordered_map<std::string, std::string>& map

@alex-breslow-amd alex-breslow-amd commented Nov 20, 2025

I may have missed it, but why not pass in ncclComm_t as a const reference rather than passing it by value?

@ahmd-k ahmd-k (author) commented Nov 20, 2025

Make const std::unordered_map<std::string, std::string>& map

In NCCLX and RCCLX, this map is not const because it is where we store some structured trace data that callers like PyTorch can use.

See the NCCLX API: https://github.com/meta-pytorch/torchcomms/blob/fe4e8116f2107b5aed0e38db10e072471ea95126/comms/ncclx/v2_27/meta/commDump.cc#L219

@ahmd-k ahmd-k (author) commented

I may have missed it, but why not pass in ncclComm_t as a const reference rather than passing it by value?

Good point; I was just following the NCCLX implementation, but I don't see why we copy the communicator here. @dmwu @YulunW any idea why the communicator is passed by value in ncclCommDump()?

@ahmd-k ahmd-k marked this pull request as draft November 21, 2025 07:25

@ahmd-k ahmd-k (author) commented Nov 21, 2025

I've updated the PR to fix issues with thread safety.

@ahmd-k ahmd-k marked this pull request as ready for review November 21, 2025 10:47