diff --git a/exllamav2/exllamav2_ext/ext_tp.cpp b/exllamav2/exllamav2_ext/ext_tp.cpp index 374f8a9a..9d6a516d 100644 --- a/exllamav2/exllamav2_ext/ext_tp.cpp +++ b/exllamav2/exllamav2_ext/ext_tp.cpp @@ -177,8 +177,6 @@ void tp_broadcast cudaStream_t stream = ctx->streams[dev]; cuda_check(cudaMemcpyAsync(target, ctx->pinned_temp[buffer], size, cudaMemcpyHostToDevice, stream)); } - - tp_cross_device_barrier(tp_context, broadcast_type, t_device); } void tp_gather @@ -349,19 +347,18 @@ void tp_cross_device_barrier { int dev_i = ctx->all_devices[i]; cudaSetDevice(dev_i); - cuda_check(cudaEventRecord(ctx->sync_events[dev_i], ctx->streams[dev_i])); - } - - for (int i = 0; i < ctx->all_devices.size(); ++i) - { - for (int j = 0; j < ctx->all_devices.size(); ++j) - { - if (i == j) continue; - int dev_i = ctx->all_devices[i]; + if (i > 0) { + int j = i - 1; int dev_j = ctx->all_devices[j]; - cudaSetDevice(dev_i); cuda_check(cudaStreamWaitEvent(ctx->streams[dev_i], ctx->sync_events[dev_j], 0)); } + cuda_check(cudaEventRecord(ctx->sync_events[dev_i], ctx->streams[dev_i])); + } + for (int i = 0; i < ctx->all_devices.size() - 1; ++i) + { + int dev_i = ctx->all_devices[i]; + cudaSetDevice(dev_i); + cuda_check(cudaStreamWaitEvent(ctx->streams[dev_i], ctx->sync_events[ctx->all_devices.back()], 0)); } }