@@ -110,7 +110,7 @@ bool ncclCeImplemented(ncclFunc_t coll, int/*ncclDevRedOp_t*/ red, ncclDataType_
110110 return false ;
111111}
112112
113- ncclResult_t ncclPrepMCSync (struct ncclComm * comm, bool isComplete, CUstreamBatchMemOpParams * batchParams, size_t * opIdx, cudaStream_t stream) {
113+ ncclResult_t ncclPrepMCSync (struct ncclComm * comm, bool isComplete, hipStreamBatchMemOpParams * batchParams, size_t * opIdx, cudaStream_t stream) {
114114 ncclResult_t ret = ncclSuccess;
115115
116116 uint32_t * readyPtrs = (uint32_t *)comm->ceColl .baseUCSymReadyPtr ;
@@ -142,7 +142,7 @@ ncclResult_t ncclPrepMCSync(struct ncclComm* comm, bool isComplete, CUstreamBatc
142142 for (int r = 0 ; r < comm->nRanks ; ++r) {
143143 if (r == comm->rank ) continue ;
144144 batchParams[*opIdx] = {};
145- batchParams[*opIdx].waitValue .operation = CU_STREAM_MEM_OP_WAIT_VALUE_32 ;
145+ // batchParams[*opIdx].waitValue.operation = HIP_STREAM_MEM_OP_WAIT_VALUE_32 ;
146146 batchParams[*opIdx].waitValue .address = (CUdeviceptr)(isComplete ? (void *)&completePtrs[r] : (void *)&readyPtrs[r]);
147147 batchParams[*opIdx].waitValue .value = waitValue;
148148 batchParams[*opIdx].waitValue .flags = CU_STREAM_WAIT_VALUE_EQ;
@@ -156,7 +156,7 @@ ncclResult_t ncclPrepMCSync(struct ncclComm* comm, bool isComplete, CUstreamBatc
156156}
157157
158158ncclResult_t ncclPrepUCSync (struct ncclComm * comm, bool isComplete,
159- CUstreamBatchMemOpParams * batchParams,
159+ hipStreamBatchMemOpParams * batchParams,
160160 size_t * opIdx) {
161161 ncclResult_t ret = ncclSuccess;
162162
@@ -175,7 +175,7 @@ ncclResult_t ncclPrepUCSync(struct ncclComm* comm, bool isComplete,
175175 size_t offset = (uint8_t *)dstPtr - (uint8_t *)comm->ceColl .ceSyncWin ->userPtr ;
176176 NCCLCHECKGOTO (ncclDevrGetLsaRankPtr (comm, comm->ceColl .ceSyncWin , offset, r, &peerDstPtr), ret, fail);
177177 batchParams[*opIdx] = {};
178- batchParams[*opIdx].writeValue .operation = CU_STREAM_MEM_OP_WRITE_VALUE_32 ;
178+ batchParams[*opIdx].writeValue .operation = hipStreamMemOpWriteValue32 ;
179179 batchParams[*opIdx].writeValue .address = (CUdeviceptr)peerDstPtr;
180180 batchParams[*opIdx].writeValue .value = waitValue;
181181 // batchParams[*opIdx].writeValue.flags = CU_STREAM_WRITE_VALUE_DEFAULT;
@@ -186,7 +186,7 @@ ncclResult_t ncclPrepUCSync(struct ncclComm* comm, bool isComplete,
186186 for (int r = 0 ; r < comm->nRanks ; ++r) {
187187 if (r == comm->rank ) continue ;
188188 batchParams[*opIdx] = {};
189- batchParams[*opIdx].waitValue .operation = CU_STREAM_MEM_OP_WAIT_VALUE_32 ;
189+ // batchParams[*opIdx].waitValue.operation = HIP_STREAM_MEM_OP_WAIT_VALUE_32 ;
190190 batchParams[*opIdx].waitValue .address = (CUdeviceptr)(isComplete ? (void *)&completePtrs[r] : (void *)&readyPtrs[r]);
191191 batchParams[*opIdx].waitValue .value = waitValue;
192192 batchParams[*opIdx].waitValue .flags = CU_STREAM_WAIT_VALUE_EQ;
@@ -212,7 +212,7 @@ ncclResult_t ncclMemOpSync(struct ncclComm* comm, cudaStream_t stream) {
212212 size_t opIdx = 0 ;
213213
214214 // Prepare batch memory operations for synchronization
215- CUstreamBatchMemOpParams * batchParams = nullptr ;
215+ hipStreamBatchMemOpParams * batchParams = nullptr ;
216216 NCCLCHECKGOTO (ncclCalloc (&batchParams, batchSize), ret, fail);
217217
218218 if (comm->nvlsSupport ) {
@@ -225,7 +225,7 @@ ncclResult_t ncclMemOpSync(struct ncclComm* comm, cudaStream_t stream) {
225225 if (ncclCudaGraphValid (comm->planner .capturingGraph )) {
226226 for (int i = 0 ; i < comm->nRanks ; i++) {
227227 batchParams[opIdx] = {};
228- batchParams[opIdx].writeValue .operation = CU_STREAM_MEM_OP_WRITE_VALUE_32 ;
228+ batchParams[opIdx].writeValue .operation = hipStreamMemOpWriteValue32 ;
229229 batchParams[opIdx].writeValue .address = (CUdeviceptr)(comm->ceColl .useCompletePtr ? (void *)&completePtrs[i] : (void *)&readyPtrs[i]);
230230 batchParams[opIdx].writeValue .value = 0 ;
231231 // batchParams[opIdx].writeValue.flags = CU_STREAM_WRITE_VALUE_DEFAULT;
@@ -234,7 +234,7 @@ ncclResult_t ncclMemOpSync(struct ncclComm* comm, cudaStream_t stream) {
234234 }
235235
236236 // Execute all memory operations in a single batch
237- CUDACHECKGOTO (cuStreamBatchMemOp (stream, opIdx, batchParams, 0 ), ret, fail);
237+ CUDACHECKGOTO (hipStreamBatchMemOp (stream, opIdx, batchParams, 0 ), ret, fail);
238238
239239 // Toggle the flag for next call
240240 comm->ceColl .useCompletePtr = !comm->ceColl .useCompletePtr ;
0 commit comments