Skip to content

Commit 718ded6

Browse files
[mlir-tensorrt] NFC: [executor] update NCCL debug logging in runtime
Adds some additional debugging logging in the NCCL executor module. GitOrigin-RevId: b628f4613e87e148ed444342932dc1cf909851e3
1 parent d37e4f8 commit 718ded6

File tree

2 files changed

+21
-2
lines changed

2 files changed

+21
-2
lines changed

mlir-tensorrt/executor/lib/Runtime/Backend/Lua/Modules/CUDA/CUDAModule.cpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,7 @@ registerCudaMemoryManagementOps(sol::state_view &lua,
227227
};
228228

229229
lua["__cuda_stream_sync"] = [](sol::this_state state, CudaStreamPtr stream) {
230+
MTRT_DBG("__cuda_stream_sync @ {0}", reinterpret_cast<void *>(stream.ptr));
230231
ADD_CUDA_MODULE_RANGE("cuda_stream_sync");
231232
SET_LUA_ERROR_IF_CUDART_ERROR(cudaStreamSynchronize(stream), state);
232233
};
@@ -439,7 +440,8 @@ registerCudaMemoryManagementOps(sol::state_view &lua,
439440
size_t srcOffset, uintptr_t dest, size_t destOffset,
440441
size_t numBytes) {
441442
ADD_CUDA_MODULE_RANGE("cuda_memcpy_host_pinned2device");
442-
MTRT_DBGF("cuda_memcpy_h2d %lu bytes from 0x%lx + %lu to 0x%lx + %lu",
443+
MTRT_DBGF("__cuda_memcpy_host_pinned2device: %lu bytes from 0x%lx + "
444+
"%lu to 0x%lx + %lu",
443445
numBytes, src, srcOffset, dest, destOffset);
444446
void *srcPtr = reinterpret_cast<void *>(src + srcOffset);
445447
void *dstPtr = reinterpret_cast<void *>(dest + destOffset);
@@ -475,7 +477,9 @@ registerCudaMemoryManagementOps(sol::state_view &lua,
475477
"expected src to be a device ptr and dest to be a host ptr");
476478
}
477479
#endif
478-
MTRT_DBGF("executor_memcpy device-host %lu bytes", numBytes);
480+
MTRT_DBGF("__cuda_memcpy_device2host_pinned: %lu bytes from 0x%lx + "
481+
"%lu to 0x%lx + %lu",
482+
numBytes, src, srcOffset, dest, destOffset);
479483
SET_LUA_ERROR_IF_CUDART_ERROR(cudaMemcpyAsync(dstPtr, srcPtr, numBytes,
480484
cudaMemcpyDeviceToHost,
481485
stream),

mlir-tensorrt/executor/lib/Runtime/Backend/Lua/Modules/NCCL/NCCLModule.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -268,6 +268,10 @@ static void registerNcclOps(sol::state_view &lua, ResourceTracker *tracker) {
268268
lua["__nccl_all_reduce_" #opsuffix "_" #typesuffix] = \
269269
[](sol::this_state state, ExecPtr sendbuff, ExecPtr recvbuff, \
270270
size_t count, uintptr_t communicator, CudaStreamPtr stream) { \
271+
MTRT_DBG("__nccl_all_reduce_" #opsuffix "_" #typesuffix \
272+
": count={0} send={1} recv={2}", \
273+
count, reinterpret_cast<void *>(sendbuff), \
274+
reinterpret_cast<void *>(recvbuff)); \
271275
auto comm = reinterpret_cast<NcclCommunicator *>(communicator); \
272276
SET_LUA_ERROR_IF_NCCL_ERROR( \
273277
ncclAllReduce(reinterpret_cast<void *>(sendbuff), \
@@ -285,6 +289,10 @@ static void registerNcclOps(sol::state_view &lua, ResourceTracker *tracker) {
285289
lua["__nccl_reduce_scatter_" #opsuffix "_" #typesuffix] = \
286290
[](sol::this_state state, ExecPtr sendbuff, ExecPtr recvbuff, \
287291
size_t recvcount, uintptr_t communicator, CudaStreamPtr stream) { \
292+
MTRT_DBG("__nccl_reduce_scatter_" #opsuffix "_" #typesuffix \
293+
": count={0} sendbuff={1} recvbuff={2}", \
294+
recvcount, reinterpret_cast<void *>(sendbuff), \
295+
reinterpret_cast<void *>(recvbuff)); \
288296
auto *comm = reinterpret_cast<NcclCommunicator *>(communicator); \
289297
SET_LUA_ERROR_IF_NCCL_ERROR( \
290298
ncclReduceScatter(reinterpret_cast<void *>(sendbuff), \
@@ -338,6 +346,13 @@ static void registerNcclOps(sol::state_view &lua, ResourceTracker *tracker) {
338346
size_t numBytes, uintptr_t communicator,
339347
CudaStreamPtr stream) {
340348
auto *comm = reinterpret_cast<NcclCommunicator *>(communicator);
349+
MTRT_DBG("__nccl_permute[{6}/{7}]: send {0} bytes @ {1} to {2}, recv {0} "
350+
"bytes @ "
351+
"{3} from {4}, comm @{5}",
352+
numBytes, reinterpret_cast<void *>(sendbuff), sendId,
353+
reinterpret_cast<void *>(recvbuff), recvId,
354+
reinterpret_cast<void *>(comm->comm), comm->rank, comm->numRanks);
355+
341356
if (recvId == -1) {
342357
// Zero out recvbuff if not receiving.
343358
SET_LUA_ERROR_IF_CUDA_ERROR(

0 commit comments

Comments
 (0)