From 3f67949be7303f129d2e889aa143d0b091b6d103 Mon Sep 17 00:00:00 2001 From: ritaw Date: Thu, 2 Sep 2021 10:55:47 +0000 Subject: [PATCH 01/13] add --- bagua-core-internal/kernels/bagua_kernels.cu | 22 +++-- ...centralized_full_precision_asynchronous.rs | 89 +++++++++++-------- bagua-core-internal/src/comm_ops/mod.rs | 7 ++ bagua-core-internal/src/datatypes/mod.rs | 50 ++++++++++- bagua-core-internal/src/kernels/mod.rs | 6 ++ bagua-core-py/src/lib.rs | 43 +++++++-- 6 files changed, 164 insertions(+), 53 deletions(-) diff --git a/bagua-core-internal/kernels/bagua_kernels.cu b/bagua-core-internal/kernels/bagua_kernels.cu index fed68c5..5ce1c61 100644 --- a/bagua-core-internal/kernels/bagua_kernels.cu +++ b/bagua-core-internal/kernels/bagua_kernels.cu @@ -257,11 +257,18 @@ __global__ void divide_inplace_f16(__half *x, float D_, int N) { __global__ void async_model_average(float *tensor, const float *reduced_tensor_copy, const float *tensor_copy, const float nranks, const int N) { for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N; i += blockDim.x * gridDim.x) { -// tensor[i] += reduced_tensor_copy[i] / nranks - tensor_copy[i]; -// if (tensor[i] != tensor[i]) { -// printf("nan encountered!"); -// } - atomicAdd(&tensor[i], reduced_tensor_copy[i] / nranks - tensor_copy[i]); + + tensor[i] += reduced_tensor_copy[i] / nranks - tensor_copy[i]; + if (tensor[i] != tensor[i]) { + printf("nan encountered!"); + } +// atomicAdd(&tensor[i], reduced_tensor_copy[i] / nranks - tensor_copy[i]); + } +} + +__global__ void fill(float *tensor, const float value, const int N) { + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N; i += blockDim.x * gridDim.x) { + tensor[i] = value; } } @@ -627,6 +634,11 @@ void async_model_average_host(float *tensor, const float *reduced_tensor_copy, CUDACHECK(cudaGetLastError()); } +void fill_host(float *tensor, float value, const int N, cudaStream_t stream) { + fill<<>>(tensor, value, N); + CUDACHECK(cudaGetLastError()); +} + //// decentralize, recvbuf should get the average of sendbuf and peer's sendbuf //ncclResult_t ncclPeerAverage(void *sendbuf, void *recvbuf, size_t sendcount, // int peer_rank, ncclDataType_t datatype, ncclComm_t comm, cudaStream_t stream) { diff --git a/bagua-core-internal/src/comm_ops/decentralized_full_precision_asynchronous.rs b/bagua-core-internal/src/comm_ops/decentralized_full_precision_asynchronous.rs index 92f2ba5..d0cb246 100644 --- a/bagua-core-internal/src/comm_ops/decentralized_full_precision_asynchronous.rs +++ b/bagua-core-internal/src/comm_ops/decentralized_full_precision_asynchronous.rs @@ -1,7 +1,7 @@ use crate::comm_ops::decentralized_full_precision_synchronous::PeerSelectionMode; use crate::comm_ops::CommOpTrait; use crate::communicators::BaguaCommunicator; -use crate::datatypes::{BaguaBucket, BaguaReductionOp, BaguaTensorRaw, RawBaguaTensor}; +use crate::datatypes::{BaguaBucket, BaguaReductionOp, BaguaTensorRaw, RawBaguaTensor, BaguaTensor}; use crate::events::BaguaEventChannel; use crate::resource_pool::{CUDA_DEVICE_MEMORY_POOL, CUDA_EVENT_POOL}; use crate::{BaguaCommOpChannels, BaguaCoreError}; @@ -13,6 +13,7 @@ pub struct DecentralizedFullPrecisionAsynchronous { pub communicator: BaguaCommunicator, pub peer_selection_mode: PeerSelectionMode, pub torch_stream: u64, + pub diff_tensor: BaguaTensor, } impl CommOpTrait for DecentralizedFullPrecisionAsynchronous { @@ -72,11 +73,23 @@ impl CommOpTrait for DecentralizedFullPrecisionAsynchronous { pool_allocations: vec![Arc::new(reduced_buf)], }; + let src_ready_event = 
CUDA_EVENT_POOL.take().event; + let dst_ready_event = CUDA_EVENT_POOL.take().event; + + unsafe { + cpp::cpp!([ + dst_ready_event as "cudaEvent_t", + comm_stream as "cudaStream_t", + torch_stream as "cudaStream_t"] + { + CUDACHECK(cudaEventRecord(dst_ready_event, comm_stream)); + CUDACHECK(cudaStreamWaitEvent(torch_stream, dst_ready_event , 0)); + }); + } + // use default stream to copy weights temp_tensor.clone_from(&t.raw, torch_stream as u64); - let src_ready_event = CUDA_EVENT_POOL.take().event; - unsafe { cpp::cpp!([ src_ready_event as "cudaEvent_t", @@ -104,41 +117,16 @@ impl CommOpTrait for DecentralizedFullPrecisionAsynchronous { } }; - let comm_ready_event = CUDA_EVENT_POOL.take().event; - - unsafe { - cpp::cpp!([ - comm_ready_event as "cudaEvent_t", - comm_stream as "cudaStream_t"] - { - CUDACHECK(cudaEventRecord(comm_ready_event, comm_stream)); - CUDACHECK(cudaEventSynchronize(comm_ready_event)); - }); - } - - if c.check_abort() { - return; - } - - // do we need to wait default stream? - unsafe { - cpp::cpp!([ - src_ready_event as "cudaEvent_t", - comm_stream as "cudaStream_t", - torch_stream as "cudaStream_t"] - { - CUDACHECK(cudaEventRecord(src_ready_event, torch_stream)); - CUDACHECK(cudaStreamWaitEvent(comm_stream, src_ready_event , 0)); - }); + { + let mut guard = self.diff_tensor.inner.write(); + guard.raw.async_model_average( + &reduced_tensor, + &temp_tensor, + c.nranks as f32, + comm_stream, + ); } - t.raw.async_model_average( - &reduced_tensor, - &temp_tensor, - c.nranks as f32, - comm_stream, - ); - unsafe { cpp::cpp!([comm_stream as "cudaStream_t"] { @@ -153,4 +141,33 @@ impl CommOpTrait for DecentralizedFullPrecisionAsynchronous { }, ); } + + fn execute_post_step( + &self, + bucket: Arc + ) { + let bucket_guard = bucket.inner.lock(); + let stream_ptr = self.communicator.stream_ptr(); + + let mut communication_tensor = match &self.communicator { + BaguaCommunicator::SingleCommunicator(_) => { + bucket_guard.get_communication_tensor(stream_ptr, false, false) + } + BaguaCommunicator::HierarchicalCommunicator(x) => { + panic!("asynchronous op only accepts non-hierarchical communicator"); + } + }; + + tracing::debug!("async model average update weight start"); + + let t = &mut communication_tensor; + + let mut guard = self.diff_tensor.inner.write(); + + t.raw.add_inplace(guard.raw.as_ref(), stream_ptr); + guard.raw.fill(0.0, stream_ptr); + + tracing::debug!("async model average update weight end"); + } + } diff --git a/bagua-core-internal/src/comm_ops/mod.rs b/bagua-core-internal/src/comm_ops/mod.rs index f22aa73..3779fe1 100644 --- a/bagua-core-internal/src/comm_ops/mod.rs +++ b/bagua-core-internal/src/comm_ops/mod.rs @@ -16,4 +16,11 @@ pub trait CommOpTrait: Debug { bucket: Arc, comm_channels: &BaguaCommOpChannels, ); + + fn execute_post_step( + &self, + bucket: Arc, + ) { + } } + diff --git a/bagua-core-internal/src/datatypes/mod.rs b/bagua-core-internal/src/datatypes/mod.rs index b7376eb..6f44838 100644 --- a/bagua-core-internal/src/datatypes/mod.rs +++ b/bagua-core-internal/src/datatypes/mod.rs @@ -6,7 +6,7 @@ use crate::comm_ops::decentralized_full_precision_synchronous::{ }; use crate::comm_ops::decentralized_low_precision_synchronous::DecentralizedLowPrecisionSynchronous; use crate::comm_ops::python_ffi_op::PythonFFIOp; -use crate::comm_ops::CommOpTrait; +use crate::comm_ops::CommOpTrait; use crate::communicators::{BaguaCommunicator, BaguaSingleCommunicator}; use crate::resource_pool::{CudaMemory, CUDA_DEVICE_MEMORY_POOL}; use 
crate::torch_ffi::root::c10::{DeviceType, StorageImpl, TensorImpl}; @@ -164,6 +164,24 @@ pub trait RawBaguaTensor: Debug { ); } } + + fn fill( + &mut self, + value: f32, + stream_ptr: u64, + ) { + let tensor_ptr = self.data_ptr(); + let total_num_elem = self.num_elements(); + unsafe { + kernels::fill_host( + tensor_ptr as _, + value as f32, + total_num_elem as i32, + stream_ptr as _, + ); + } + } + fn substract_inplace(&mut self, other: &dyn RawBaguaTensor, stream_ptr: u64) { assert_eq!(self.dtype(), other.dtype()); assert_eq!(self.num_elements(), other.num_elements()); @@ -1069,6 +1087,25 @@ impl<'b> Drop for BaguaCommunicationTensor<'b> { } } +#[derive(Debug, Clone)] +pub struct BaguaCommOp { + pub name: String, + pub inner: Arc +} + +impl BaguaCommOp { + pub fn execute_post_step( + &self, + bucket: &BaguaBucket, + ) -> Result<(), BaguaCoreError> { + + let bucket = Arc::new((*bucket).clone()); + self.inner.execute_post_step(bucket); + + Ok(()) + } +} + #[derive(Debug, Clone)] pub struct BaguaBucket { pub name: String, @@ -1232,7 +1269,8 @@ impl BaguaBucket { communicator_intranode: Option<&BaguaSingleCommunicator>, peer_selection_mode: String, torch_stream: u64, - ) { + diff_tensor: BaguaTensor, + ) -> BaguaCommOp { let communicator = BaguaCommunicator::new(communicator_internode, communicator_intranode, false) .expect("cannot create communicator"); @@ -1247,10 +1285,16 @@ impl BaguaBucket { } }, torch_stream, + diff_tensor, }, ); - self.inner.lock().comm_ops.push(comm_op); + self.inner.lock().comm_ops.push(comm_op.clone()); + + BaguaCommOp { + name: String::from("decentralized_async_op"), + inner: comm_op, + } } pub fn ready_for_comm(&self) -> bool { diff --git a/bagua-core-internal/src/kernels/mod.rs b/bagua-core-internal/src/kernels/mod.rs index 55df19e..1fa7ec3 100644 --- a/bagua-core-internal/src/kernels/mod.rs +++ b/bagua-core-internal/src/kernels/mod.rs @@ -134,4 +134,10 @@ extern "C" { N: i32, stream: *const c_void, ); + pub fn fill_host( + tensor: *mut c_void, + value: f32, + N: i32, + stream: *const c_void, + ); } diff --git a/bagua-core-py/src/lib.rs b/bagua-core-py/src/lib.rs index 28a2eb6..aba2ab6 100644 --- a/bagua-core-py/src/lib.rs +++ b/bagua-core-py/src/lib.rs @@ -2,7 +2,7 @@ use bagua_core_internal::communicators::BaguaSingleCommunicator; use bagua_core_internal::datatypes::{ - BaguaBucket, BaguaReductionOp, BaguaTensor, BaguaTensorDtype, + BaguaBucket, BaguaReductionOp, BaguaTensor, BaguaTensorDtype, BaguaCommOp }; use bagua_core_internal::BaguaCommBackend; use num_traits::FromPrimitive; @@ -193,6 +193,26 @@ impl BaguaSingleCommunicatorPy { } } +#[pyclass(dict)] +pub struct BaguaCommOpPy { + inner: BaguaCommOp +} + +#[pymethods] +impl BaguaCommOpPy { + + pub fn execute_post_step(&self, + bucket: PyRef, + ) -> PyResult<()> { + let bucket_inner = &bucket.inner; + + self.inner.execute_post_step(bucket_inner) + .map_err(|e| PyRuntimeError::new_err(format!("{:?}", e))) + + } + +} + #[pyclass(dict)] pub struct BaguaTensorPy { inner: BaguaTensor, @@ -347,6 +367,7 @@ impl BaguaCommBackendPy { py.allow_threads(|| self.inner.wait_pending_comm_ops()) .map_err(|e| PyRuntimeError::new_err(format!("{:?}", e))) } + } #[pyclass(dict)] @@ -456,14 +477,17 @@ impl BaguaBucketPy { communicator_intranode: Option<&BaguaSingleCommunicatorPy>, peer_selection_mode: String, torch_stream: u64, - ) -> PyResult<()> { - self.inner.append_decentralized_asynchronous_op( - communicator_internode.map(|x| &x.inner), - communicator_intranode.map(|x| &x.inner), - peer_selection_mode, - torch_stream, - ); 
- Ok(()) + diff_tensor: PyRef, + ) -> BaguaCommOpPy { + BaguaCommOpPy { + inner: self.inner.append_decentralized_asynchronous_op( + communicator_internode.map(|x| &x.inner), + communicator_intranode.map(|x| &x.inner), + peer_selection_mode, + torch_stream, + (*diff_tensor).inner.clone(), + ) + } } pub fn print_ops(&self) -> PyResult<()> { @@ -507,6 +531,7 @@ fn bagua_core(_py: Python, m: &PyModule) -> PyResult<()> { m.add_class::()?; m.add_class::()?; m.add_class::()?; + m.add_class::()?; #[pyfn(m, "show_version")] fn show_version(_py: Python) { From b82dc5cc8ab66b9a804f18ece60ed4b600907893 Mon Sep 17 00:00:00 2001 From: ritaw Date: Fri, 3 Sep 2021 15:17:53 +0800 Subject: [PATCH 02/13] tmp save --- ...centralized_full_precision_asynchronous.rs | 84 +++++++++++++------ bagua-core-internal/src/datatypes/mod.rs | 1 + bagua-core-internal/src/lib.rs | 15 +++- bagua-core-py/src/lib.rs | 22 +++-- 4 files changed, 92 insertions(+), 30 deletions(-) diff --git a/bagua-core-internal/src/comm_ops/decentralized_full_precision_asynchronous.rs b/bagua-core-internal/src/comm_ops/decentralized_full_precision_asynchronous.rs index d0cb246..8449907 100644 --- a/bagua-core-internal/src/comm_ops/decentralized_full_precision_asynchronous.rs +++ b/bagua-core-internal/src/comm_ops/decentralized_full_precision_asynchronous.rs @@ -23,6 +23,8 @@ impl CommOpTrait for DecentralizedFullPrecisionAsynchronous { comm_op_channels: &BaguaCommOpChannels, ) { let bucket_guard = bucket.inner.lock(); + let mut tensor_guard = self.diff_tensor.inner.write(); + let comm_stream = self.communicator.stream_ptr(); let mut communication_tensor = match &self.communicator { @@ -45,7 +47,7 @@ impl CommOpTrait for DecentralizedFullPrecisionAsynchronous { false, &mut |c, t| { let start_time = std::time::Instant::now(); - tracing::debug!("async model average start"); + tracing::debug!("#{} async model average start", c.rank); let temp_buf = CUDA_DEVICE_MEMORY_POOL[t.raw.device_id()] .try_pull(t.raw.num_elements_allocated() * t.raw.dtype().bytes()) @@ -117,26 +119,24 @@ impl CommOpTrait for DecentralizedFullPrecisionAsynchronous { } }; - { - let mut guard = self.diff_tensor.inner.write(); - guard.raw.async_model_average( - &reduced_tensor, - &temp_tensor, - c.nranks as f32, - comm_stream, - ); - } + + tensor_guard.raw.async_model_average( + &reduced_tensor, + &temp_tensor, + c.nranks as f32, + comm_stream, + ); - unsafe { + /*unsafe { cpp::cpp!([comm_stream as "cudaStream_t"] { CUDACHECK(cudaStreamSynchronize(comm_stream)); }); - } + }*/ tracing::debug!( - "async model average update cost: {:?}", - start_time.elapsed() + "#{} async model average update cost: {:?}", + c.rank, start_time.elapsed() ); }, ); @@ -146,28 +146,64 @@ impl CommOpTrait for DecentralizedFullPrecisionAsynchronous { &self, bucket: Arc ) { + let bucket_guard = bucket.inner.lock(); - let stream_ptr = self.communicator.stream_ptr(); - + let mut tensor_guard = self.diff_tensor.inner.write(); + + let comm_stream = self.communicator.stream_ptr(); + let mut communication_tensor = match &self.communicator { BaguaCommunicator::SingleCommunicator(_) => { - bucket_guard.get_communication_tensor(stream_ptr, false, false) + bucket_guard.get_communication_tensor(comm_stream, false, false) } BaguaCommunicator::HierarchicalCommunicator(x) => { panic!("asynchronous op only accepts non-hierarchical communicator"); } }; - tracing::debug!("async model average update weight start"); + let torch_stream = self.torch_stream; - let t = &mut communication_tensor; - - let mut guard = 
self.diff_tensor.inner.write(); + self.communicator.execute_communication( + &mut communication_tensor, + false, + false, + false, + &mut |c, t| { - t.raw.add_inplace(guard.raw.as_ref(), stream_ptr); - guard.raw.fill(0.0, stream_ptr); + tracing::debug!("#{} async model average update weight start", c.rank); + + let src_ready_event = CUDA_EVENT_POOL.take().event; + let dst_ready_event = CUDA_EVENT_POOL.take().event; + + unsafe { + cpp::cpp!([ + src_ready_event as "cudaEvent_t", + comm_stream as "cudaStream_t", + torch_stream as "cudaStream_t"] + { + CUDACHECK(cudaEventRecord(src_ready_event, torch_stream)); + CUDACHECK(cudaStreamWaitEvent(comm_stream, src_ready_event , 0)); + }); + } + + t.raw.add_inplace(tensor_guard.raw.as_ref(), comm_stream); + tensor_guard.raw.fill(0.0, comm_stream); + + unsafe { + cpp::cpp!([ + dst_ready_event as "cudaEvent_t", + comm_stream as "cudaStream_t", + torch_stream as "cudaStream_t"] + { + CUDACHECK(cudaEventRecord(dst_ready_event, comm_stream)); + CUDACHECK(cudaStreamWaitEvent(torch_stream, dst_ready_event , 0)); + }); + } - tracing::debug!("async model average update weight end"); + tracing::debug!("#{} async model average update weight end", c.rank); + }, + ); + } } diff --git a/bagua-core-internal/src/datatypes/mod.rs b/bagua-core-internal/src/datatypes/mod.rs index 6f44838..1b55089 100644 --- a/bagua-core-internal/src/datatypes/mod.rs +++ b/bagua-core-internal/src/datatypes/mod.rs @@ -1099,6 +1099,7 @@ impl BaguaCommOp { bucket: &BaguaBucket, ) -> Result<(), BaguaCoreError> { + tracing::debug!("BaguaCommOp: execute_post_step"); let bucket = Arc::new((*bucket).clone()); self.inner.execute_post_step(bucket); diff --git a/bagua-core-internal/src/lib.rs b/bagua-core-internal/src/lib.rs index caead44..6c6f014 100644 --- a/bagua-core-internal/src/lib.rs +++ b/bagua-core-internal/src/lib.rs @@ -15,7 +15,7 @@ mod torch_ffi; use crate::comm_ops::CommOpTrait; use bagua_opentelemetry; use cpp::cpp; -use datatypes::{BaguaBucket, BaguaTensor}; +use datatypes::{BaguaBucket, BaguaTensor, BaguaCommOp}; use events::BaguaEventChannel; use flume::RecvTimeoutError; use hashbrown::{HashMap, HashSet}; @@ -335,4 +335,17 @@ impl BaguaCommBackend { } } } + + pub fn execute_post_comm_step( + &self, + bucket: &BaguaBucket, + op: &BaguaCommOp + ) -> Result<(), BaguaCoreError> { + tracing::debug!("bagua-core-internal: schedule_comm_post_step"); + + let bucket = Arc::new((*bucket).clone()); + op.inner.execute_post_step(bucket); + + Ok(()) + } } diff --git a/bagua-core-py/src/lib.rs b/bagua-core-py/src/lib.rs index aba2ab6..413b020 100644 --- a/bagua-core-py/src/lib.rs +++ b/bagua-core-py/src/lib.rs @@ -201,14 +201,12 @@ pub struct BaguaCommOpPy { #[pymethods] impl BaguaCommOpPy { - pub fn execute_post_step(&self, - bucket: PyRef, - ) -> PyResult<()> { + pub fn execute_post_step(&self, bucket: PyRef) -> PyResult<()> { + tracing::debug!("BaguaCommOpPy: execute_post_step"); let bucket_inner = &bucket.inner; - self.inner.execute_post_step(bucket_inner) .map_err(|e| PyRuntimeError::new_err(format!("{:?}", e))) - + } } @@ -368,6 +366,20 @@ impl BaguaCommBackendPy { .map_err(|e| PyRuntimeError::new_err(format!("{:?}", e))) } + pub fn execute_post_comm_step( + &self, + bucket: PyRef, + op: PyRef, + py: Python + ) -> PyResult<()> { + + let bucket_inner = &bucket.inner; + let op_inner = &op.inner; + tracing::debug!("bagua-core-py: schedule_comm_post_step"); + self.inner.execute_post_comm_step(bucket_inner, op_inner) + .map_err(|e| PyRuntimeError::new_err(format!("{:?}", e))) + } + } 
#[pyclass(dict)] From ce564eabba7d6a12807baf707d73a87185df3303 Mon Sep 17 00:00:00 2001 From: ritaw Date: Fri, 3 Sep 2021 17:36:38 +0800 Subject: [PATCH 03/13] update async --- bagua-core-internal/kernels/bagua_kernels.cu | 9 +- ...centralized_full_precision_asynchronous.rs | 90 ++++--------------- bagua-core-internal/src/datatypes/mod.rs | 14 +-- bagua-core-internal/src/kernels/mod.rs | 4 +- bagua-core-internal/src/lib.rs | 14 +-- bagua-core-py/src/lib.rs | 17 +--- 6 files changed, 38 insertions(+), 110 deletions(-) diff --git a/bagua-core-internal/kernels/bagua_kernels.cu b/bagua-core-internal/kernels/bagua_kernels.cu index 5ce1c61..34eaea6 100644 --- a/bagua-core-internal/kernels/bagua_kernels.cu +++ b/bagua-core-internal/kernels/bagua_kernels.cu @@ -266,9 +266,10 @@ __global__ void async_model_average(float *tensor, const float *reduced_tensor_c } } -__global__ void fill(float *tensor, const float value, const int N) { +__global__ void async_model_update(float *tensor, float *diff_tensor, const int N) { for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N; i += blockDim.x * gridDim.x) { - tensor[i] = value; + tensor[i] += diff_tensor[i]; + diff_tensor[i] = 0.0; } } @@ -634,8 +635,8 @@ void async_model_average_host(float *tensor, const float *reduced_tensor_copy, CUDACHECK(cudaGetLastError()); } -void fill_host(float *tensor, float value, const int N, cudaStream_t stream) { - fill<<>>(tensor, value, N); +void async_model_update_host(float *tensor, float *diff_tensor, const int N, cudaStream_t stream) { + async_model_update<<>>(tensor, diff_tensor, N); CUDACHECK(cudaGetLastError()); } diff --git a/bagua-core-internal/src/comm_ops/decentralized_full_precision_asynchronous.rs b/bagua-core-internal/src/comm_ops/decentralized_full_precision_asynchronous.rs index 8449907..4c86cc1 100644 --- a/bagua-core-internal/src/comm_ops/decentralized_full_precision_asynchronous.rs +++ b/bagua-core-internal/src/comm_ops/decentralized_full_precision_asynchronous.rs @@ -13,6 +13,7 @@ pub struct DecentralizedFullPrecisionAsynchronous { pub communicator: BaguaCommunicator, pub peer_selection_mode: PeerSelectionMode, pub torch_stream: u64, + pub weight: BaguaTensor, pub diff_tensor: BaguaTensor, } @@ -23,7 +24,6 @@ impl CommOpTrait for DecentralizedFullPrecisionAsynchronous { comm_op_channels: &BaguaCommOpChannels, ) { let bucket_guard = bucket.inner.lock(); - let mut tensor_guard = self.diff_tensor.inner.write(); let comm_stream = self.communicator.stream_ptr(); @@ -103,10 +103,10 @@ impl CommOpTrait for DecentralizedFullPrecisionAsynchronous { }); } - if c.check_abort() { + /* if c.check_abort() { return; } - +*/ match peer_mode { PeerSelectionMode::All => { c.allreduce(&temp_tensor, &mut reduced_tensor, BaguaReductionOp::SUM); @@ -119,20 +119,15 @@ impl CommOpTrait for DecentralizedFullPrecisionAsynchronous { } }; - - tensor_guard.raw.async_model_average( - &reduced_tensor, - &temp_tensor, - c.nranks as f32, - comm_stream, - ); - - /*unsafe { - cpp::cpp!([comm_stream as "cudaStream_t"] - { - CUDACHECK(cudaStreamSynchronize(comm_stream)); - }); - }*/ + { + let mut tensor_guard = self.diff_tensor.inner.write(); + tensor_guard.raw.async_model_average( + &reduced_tensor, + &temp_tensor, + c.nranks as f32, + comm_stream, + ); + } tracing::debug!( "#{} async model average update cost: {:?}", @@ -147,63 +142,16 @@ impl CommOpTrait for DecentralizedFullPrecisionAsynchronous { bucket: Arc ) { - let bucket_guard = bucket.inner.lock(); + tracing::debug!("async update weight start"); + let torch_stream = 
self.torch_stream; + let mut tensor_guard = self.diff_tensor.inner.write(); + let mut guard = self.weight.inner.write(); + + guard.raw.async_model_update(tensor_guard.raw.as_ref(), torch_stream); - let comm_stream = self.communicator.stream_ptr(); + tracing::debug!("async update weight end"); - let mut communication_tensor = match &self.communicator { - BaguaCommunicator::SingleCommunicator(_) => { - bucket_guard.get_communication_tensor(comm_stream, false, false) - } - BaguaCommunicator::HierarchicalCommunicator(x) => { - panic!("asynchronous op only accepts non-hierarchical communicator"); - } - }; - - let torch_stream = self.torch_stream; - - self.communicator.execute_communication( - &mut communication_tensor, - false, - false, - false, - &mut |c, t| { - - tracing::debug!("#{} async model average update weight start", c.rank); - - let src_ready_event = CUDA_EVENT_POOL.take().event; - let dst_ready_event = CUDA_EVENT_POOL.take().event; - - unsafe { - cpp::cpp!([ - src_ready_event as "cudaEvent_t", - comm_stream as "cudaStream_t", - torch_stream as "cudaStream_t"] - { - CUDACHECK(cudaEventRecord(src_ready_event, torch_stream)); - CUDACHECK(cudaStreamWaitEvent(comm_stream, src_ready_event , 0)); - }); - } - - t.raw.add_inplace(tensor_guard.raw.as_ref(), comm_stream); - tensor_guard.raw.fill(0.0, comm_stream); - - unsafe { - cpp::cpp!([ - dst_ready_event as "cudaEvent_t", - comm_stream as "cudaStream_t", - torch_stream as "cudaStream_t"] - { - CUDACHECK(cudaEventRecord(dst_ready_event, comm_stream)); - CUDACHECK(cudaStreamWaitEvent(torch_stream, dst_ready_event , 0)); - }); - } - - tracing::debug!("#{} async model average update weight end", c.rank); - }, - ); - } } diff --git a/bagua-core-internal/src/datatypes/mod.rs b/bagua-core-internal/src/datatypes/mod.rs index 1b55089..843651b 100644 --- a/bagua-core-internal/src/datatypes/mod.rs +++ b/bagua-core-internal/src/datatypes/mod.rs @@ -165,17 +165,20 @@ pub trait RawBaguaTensor: Debug { } } - fn fill( + fn async_model_update( &mut self, - value: f32, + diff_tensor: &dyn RawBaguaTensor, stream_ptr: u64, ) { + assert_eq!(self.dtype(), diff_tensor.dtype()); + assert_eq!(self.num_elements(), diff_tensor.num_elements()); + let tensor_ptr = self.data_ptr(); let total_num_elem = self.num_elements(); unsafe { - kernels::fill_host( + kernels::async_model_update_host( tensor_ptr as _, - value as f32, + diff_tensor.data_ptr() as _, total_num_elem as i32, stream_ptr as _, ); @@ -1099,7 +1102,6 @@ impl BaguaCommOp { bucket: &BaguaBucket, ) -> Result<(), BaguaCoreError> { - tracing::debug!("BaguaCommOp: execute_post_step"); let bucket = Arc::new((*bucket).clone()); self.inner.execute_post_step(bucket); @@ -1270,6 +1272,7 @@ impl BaguaBucket { communicator_intranode: Option<&BaguaSingleCommunicator>, peer_selection_mode: String, torch_stream: u64, + weight: BaguaTensor, diff_tensor: BaguaTensor, ) -> BaguaCommOp { let communicator = @@ -1286,6 +1289,7 @@ impl BaguaBucket { } }, torch_stream, + weight, diff_tensor, }, ); diff --git a/bagua-core-internal/src/kernels/mod.rs b/bagua-core-internal/src/kernels/mod.rs index 1fa7ec3..3fc6ad3 100644 --- a/bagua-core-internal/src/kernels/mod.rs +++ b/bagua-core-internal/src/kernels/mod.rs @@ -134,9 +134,9 @@ extern "C" { N: i32, stream: *const c_void, ); - pub fn fill_host( + pub fn async_model_update_host( tensor: *mut c_void, - value: f32, + diff_tensor: *mut c_void, N: i32, stream: *const c_void, ); diff --git a/bagua-core-internal/src/lib.rs b/bagua-core-internal/src/lib.rs index 6c6f014..786d349 100644 --- 
a/bagua-core-internal/src/lib.rs +++ b/bagua-core-internal/src/lib.rs @@ -15,7 +15,7 @@ mod torch_ffi; use crate::comm_ops::CommOpTrait; use bagua_opentelemetry; use cpp::cpp; -use datatypes::{BaguaBucket, BaguaTensor, BaguaCommOp}; +use datatypes::{BaguaBucket, BaguaTensor}; use events::BaguaEventChannel; use flume::RecvTimeoutError; use hashbrown::{HashMap, HashSet}; @@ -336,16 +336,4 @@ impl BaguaCommBackend { } } - pub fn execute_post_comm_step( - &self, - bucket: &BaguaBucket, - op: &BaguaCommOp - ) -> Result<(), BaguaCoreError> { - tracing::debug!("bagua-core-internal: schedule_comm_post_step"); - - let bucket = Arc::new((*bucket).clone()); - op.inner.execute_post_step(bucket); - - Ok(()) - } } diff --git a/bagua-core-py/src/lib.rs b/bagua-core-py/src/lib.rs index 413b020..66dd006 100644 --- a/bagua-core-py/src/lib.rs +++ b/bagua-core-py/src/lib.rs @@ -202,7 +202,6 @@ pub struct BaguaCommOpPy { impl BaguaCommOpPy { pub fn execute_post_step(&self, bucket: PyRef) -> PyResult<()> { - tracing::debug!("BaguaCommOpPy: execute_post_step"); let bucket_inner = &bucket.inner; self.inner.execute_post_step(bucket_inner) .map_err(|e| PyRuntimeError::new_err(format!("{:?}", e))) @@ -366,20 +365,6 @@ impl BaguaCommBackendPy { .map_err(|e| PyRuntimeError::new_err(format!("{:?}", e))) } - pub fn execute_post_comm_step( - &self, - bucket: PyRef, - op: PyRef, - py: Python - ) -> PyResult<()> { - - let bucket_inner = &bucket.inner; - let op_inner = &op.inner; - tracing::debug!("bagua-core-py: schedule_comm_post_step"); - self.inner.execute_post_comm_step(bucket_inner, op_inner) - .map_err(|e| PyRuntimeError::new_err(format!("{:?}", e))) - } - } #[pyclass(dict)] @@ -489,6 +474,7 @@ impl BaguaBucketPy { communicator_intranode: Option<&BaguaSingleCommunicatorPy>, peer_selection_mode: String, torch_stream: u64, + weight: PyRef, diff_tensor: PyRef, ) -> BaguaCommOpPy { BaguaCommOpPy { @@ -497,6 +483,7 @@ impl BaguaBucketPy { communicator_intranode.map(|x| &x.inner), peer_selection_mode, torch_stream, + (*weight).inner.clone(), (*diff_tensor).inner.clone(), ) } From 4ef4081607af4c0503b917b8576a9f9ce8dc108e Mon Sep 17 00:00:00 2001 From: ritaw Date: Fri, 3 Sep 2021 17:55:02 +0800 Subject: [PATCH 04/13] fmt --- ...centralized_full_precision_asynchronous.rs | 44 ++++++++----------- bagua-core-internal/src/comm_ops/mod.rs | 9 +--- bagua-core-internal/src/datatypes/mod.rs | 18 +++----- bagua-core-internal/src/lib.rs | 1 - bagua-core-py/src/lib.rs | 15 +++---- 5 files changed, 32 insertions(+), 55 deletions(-) diff --git a/bagua-core-internal/src/comm_ops/decentralized_full_precision_asynchronous.rs b/bagua-core-internal/src/comm_ops/decentralized_full_precision_asynchronous.rs index 4c86cc1..b88ec98 100644 --- a/bagua-core-internal/src/comm_ops/decentralized_full_precision_asynchronous.rs +++ b/bagua-core-internal/src/comm_ops/decentralized_full_precision_asynchronous.rs @@ -1,7 +1,9 @@ use crate::comm_ops::decentralized_full_precision_synchronous::PeerSelectionMode; use crate::comm_ops::CommOpTrait; use crate::communicators::BaguaCommunicator; -use crate::datatypes::{BaguaBucket, BaguaReductionOp, BaguaTensorRaw, RawBaguaTensor, BaguaTensor}; +use crate::datatypes::{ + BaguaBucket, BaguaReductionOp, BaguaTensor, BaguaTensorRaw, RawBaguaTensor, +}; use crate::events::BaguaEventChannel; use crate::resource_pool::{CUDA_DEVICE_MEMORY_POOL, CUDA_EVENT_POOL}; use crate::{BaguaCommOpChannels, BaguaCoreError}; @@ -46,8 +48,7 @@ impl CommOpTrait for DecentralizedFullPrecisionAsynchronous { false, false, &mut |c, t| { - 
let start_time = std::time::Instant::now(); - tracing::debug!("#{} async model average start", c.rank); + tracing::debug!("async model average start"); let temp_buf = CUDA_DEVICE_MEMORY_POOL[t.raw.device_id()] .try_pull(t.raw.num_elements_allocated() * t.raw.dtype().bytes()) @@ -77,7 +78,7 @@ impl CommOpTrait for DecentralizedFullPrecisionAsynchronous { let src_ready_event = CUDA_EVENT_POOL.take().event; let dst_ready_event = CUDA_EVENT_POOL.take().event; - + unsafe { cpp::cpp!([ dst_ready_event as "cudaEvent_t", @@ -103,10 +104,10 @@ impl CommOpTrait for DecentralizedFullPrecisionAsynchronous { }); } - /* if c.check_abort() { + if c.check_abort() { return; } -*/ + match peer_mode { PeerSelectionMode::All => { c.allreduce(&temp_tensor, &mut reduced_tensor, BaguaReductionOp::SUM); @@ -119,7 +120,7 @@ impl CommOpTrait for DecentralizedFullPrecisionAsynchronous { } }; - { + { let mut tensor_guard = self.diff_tensor.inner.write(); tensor_guard.raw.async_model_average( &reduced_tensor, @@ -129,29 +130,22 @@ impl CommOpTrait for DecentralizedFullPrecisionAsynchronous { ); } - tracing::debug!( - "#{} async model average update cost: {:?}", - c.rank, start_time.elapsed() - ); + tracing::debug!("async model average update cost end"); }, ); } - - fn execute_post_step( - &self, - bucket: Arc - ) { - - tracing::debug!("async update weight start"); + + fn execute_post_step(&self, bucket: Arc) { + tracing::debug!("async model average post step start"); let torch_stream = self.torch_stream; - + let mut tensor_guard = self.diff_tensor.inner.write(); let mut guard = self.weight.inner.write(); - - guard.raw.async_model_update(tensor_guard.raw.as_ref(), torch_stream); - - tracing::debug!("async update weight end"); - - } + guard + .raw + .async_model_update(tensor_guard.raw.as_ref(), torch_stream); + + tracing::debug!("async model average post step end"); + } } diff --git a/bagua-core-internal/src/comm_ops/mod.rs b/bagua-core-internal/src/comm_ops/mod.rs index 3779fe1..9598f22 100644 --- a/bagua-core-internal/src/comm_ops/mod.rs +++ b/bagua-core-internal/src/comm_ops/mod.rs @@ -16,11 +16,6 @@ pub trait CommOpTrait: Debug { bucket: Arc, comm_channels: &BaguaCommOpChannels, ); - - fn execute_post_step( - &self, - bucket: Arc, - ) { - } -} + fn execute_post_step(&self, bucket: Arc) {} +} diff --git a/bagua-core-internal/src/datatypes/mod.rs b/bagua-core-internal/src/datatypes/mod.rs index 843651b..2674a4a 100644 --- a/bagua-core-internal/src/datatypes/mod.rs +++ b/bagua-core-internal/src/datatypes/mod.rs @@ -6,7 +6,7 @@ use crate::comm_ops::decentralized_full_precision_synchronous::{ }; use crate::comm_ops::decentralized_low_precision_synchronous::DecentralizedLowPrecisionSynchronous; use crate::comm_ops::python_ffi_op::PythonFFIOp; -use crate::comm_ops::CommOpTrait; +use crate::comm_ops::CommOpTrait; use crate::communicators::{BaguaCommunicator, BaguaSingleCommunicator}; use crate::resource_pool::{CudaMemory, CUDA_DEVICE_MEMORY_POOL}; use crate::torch_ffi::root::c10::{DeviceType, StorageImpl, TensorImpl}; @@ -165,14 +165,10 @@ pub trait RawBaguaTensor: Debug { } } - fn async_model_update( - &mut self, - diff_tensor: &dyn RawBaguaTensor, - stream_ptr: u64, - ) { + fn async_model_update(&mut self, diff_tensor: &dyn RawBaguaTensor, stream_ptr: u64) { assert_eq!(self.dtype(), diff_tensor.dtype()); assert_eq!(self.num_elements(), diff_tensor.num_elements()); - + let tensor_ptr = self.data_ptr(); let total_num_elem = self.num_elements(); unsafe { @@ -1093,15 +1089,11 @@ impl<'b> Drop for BaguaCommunicationTensor<'b> 
{ #[derive(Debug, Clone)] pub struct BaguaCommOp { pub name: String, - pub inner: Arc + pub inner: Arc, } impl BaguaCommOp { - pub fn execute_post_step( - &self, - bucket: &BaguaBucket, - ) -> Result<(), BaguaCoreError> { - + pub fn execute_post_step(&self, bucket: &BaguaBucket) -> Result<(), BaguaCoreError> { let bucket = Arc::new((*bucket).clone()); self.inner.execute_post_step(bucket); diff --git a/bagua-core-internal/src/lib.rs b/bagua-core-internal/src/lib.rs index 786d349..caead44 100644 --- a/bagua-core-internal/src/lib.rs +++ b/bagua-core-internal/src/lib.rs @@ -335,5 +335,4 @@ impl BaguaCommBackend { } } } - } diff --git a/bagua-core-py/src/lib.rs b/bagua-core-py/src/lib.rs index 66dd006..098c694 100644 --- a/bagua-core-py/src/lib.rs +++ b/bagua-core-py/src/lib.rs @@ -2,7 +2,7 @@ use bagua_core_internal::communicators::BaguaSingleCommunicator; use bagua_core_internal::datatypes::{ - BaguaBucket, BaguaReductionOp, BaguaTensor, BaguaTensorDtype, BaguaCommOp + BaguaBucket, BaguaCommOp, BaguaReductionOp, BaguaTensor, BaguaTensorDtype, }; use bagua_core_internal::BaguaCommBackend; use num_traits::FromPrimitive; @@ -195,19 +195,17 @@ impl BaguaSingleCommunicatorPy { #[pyclass(dict)] pub struct BaguaCommOpPy { - inner: BaguaCommOp + inner: BaguaCommOp, } #[pymethods] impl BaguaCommOpPy { - - pub fn execute_post_step(&self, bucket: PyRef) -> PyResult<()> { + pub fn execute_post_step(&self, bucket: PyRef) -> PyResult<()> { let bucket_inner = &bucket.inner; - self.inner.execute_post_step(bucket_inner) + self.inner + .execute_post_step(bucket_inner) .map_err(|e| PyRuntimeError::new_err(format!("{:?}", e))) - } - } #[pyclass(dict)] @@ -364,7 +362,6 @@ impl BaguaCommBackendPy { py.allow_threads(|| self.inner.wait_pending_comm_ops()) .map_err(|e| PyRuntimeError::new_err(format!("{:?}", e))) } - } #[pyclass(dict)] @@ -485,7 +482,7 @@ impl BaguaBucketPy { torch_stream, (*weight).inner.clone(), (*diff_tensor).inner.clone(), - ) + ), } } From 64476415ae0540e7942963465a9d0fd9518ceba2 Mon Sep 17 00:00:00 2001 From: ritaw Date: Mon, 6 Sep 2021 16:20:10 +0800 Subject: [PATCH 05/13] remove nan print --- bagua-core-internal/kernels/bagua_kernels.cu | 4 ++-- .../src/comm_ops/decentralized_full_precision_asynchronous.rs | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/bagua-core-internal/kernels/bagua_kernels.cu b/bagua-core-internal/kernels/bagua_kernels.cu index 34eaea6..513b857 100644 --- a/bagua-core-internal/kernels/bagua_kernels.cu +++ b/bagua-core-internal/kernels/bagua_kernels.cu @@ -259,9 +259,9 @@ __global__ void async_model_average(float *tensor, const float *reduced_tensor_c for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N; i += blockDim.x * gridDim.x) { tensor[i] += reduced_tensor_copy[i] / nranks - tensor_copy[i]; - if (tensor[i] != tensor[i]) { + /*if (tensor[i] != tensor[i]) { printf("nan encountered!"); - } + }*/ // atomicAdd(&tensor[i], reduced_tensor_copy[i] / nranks - tensor_copy[i]); } } diff --git a/bagua-core-internal/src/comm_ops/decentralized_full_precision_asynchronous.rs b/bagua-core-internal/src/comm_ops/decentralized_full_precision_asynchronous.rs index b88ec98..ead1008 100644 --- a/bagua-core-internal/src/comm_ops/decentralized_full_precision_asynchronous.rs +++ b/bagua-core-internal/src/comm_ops/decentralized_full_precision_asynchronous.rs @@ -130,7 +130,7 @@ impl CommOpTrait for DecentralizedFullPrecisionAsynchronous { ); } - tracing::debug!("async model average update cost end"); + tracing::debug!("async model average update end"); }, ); } From 
961c746fa9331b7e29354c94bff24312a570559a Mon Sep 17 00:00:00 2001 From: ritaw Date: Tue, 7 Sep 2021 11:04:29 +0800 Subject: [PATCH 06/13] for mutex --- ...centralized_full_precision_asynchronous.rs | 41 +++++++++++++++---- 1 file changed, 34 insertions(+), 7 deletions(-) diff --git a/bagua-core-internal/src/comm_ops/decentralized_full_precision_asynchronous.rs b/bagua-core-internal/src/comm_ops/decentralized_full_precision_asynchronous.rs index ead1008..d0e0cad 100644 --- a/bagua-core-internal/src/comm_ops/decentralized_full_precision_asynchronous.rs +++ b/bagua-core-internal/src/comm_ops/decentralized_full_precision_asynchronous.rs @@ -48,6 +48,7 @@ impl CommOpTrait for DecentralizedFullPrecisionAsynchronous { false, false, &mut |c, t| { + let start_time = std::time::Instant::now(); tracing::debug!("async model average start"); let temp_buf = CUDA_DEVICE_MEMORY_POOL[t.raw.device_id()] @@ -128,9 +129,23 @@ impl CommOpTrait for DecentralizedFullPrecisionAsynchronous { c.nranks as f32, comm_stream, ); - } - tracing::debug!("async model average update end"); + let ready_event = CUDA_EVENT_POOL.take().event; + + unsafe { + cpp::cpp!([ + ready_event as "cudaEvent_t", + comm_stream as "cudaStream_t"] + { + CUDACHECK(cudaEventRecord(ready_event, comm_stream)); + CUDACHECK(cudaEventSynchronize(ready_event)); + }); + } + } + tracing::debug!( + "async model average update cost: {:?}", + start_time.elapsed() + ); }, ); } @@ -138,14 +153,26 @@ impl CommOpTrait for DecentralizedFullPrecisionAsynchronous { fn execute_post_step(&self, bucket: Arc) { tracing::debug!("async model average post step start"); let torch_stream = self.torch_stream; + { + let mut tensor_guard = self.diff_tensor.inner.write(); + let mut guard = self.weight.inner.write(); - let mut tensor_guard = self.diff_tensor.inner.write(); - let mut guard = self.weight.inner.write(); + guard + .raw + .async_model_update(tensor_guard.raw.as_ref(), torch_stream); - guard - .raw - .async_model_update(tensor_guard.raw.as_ref(), torch_stream); + let ready_event = CUDA_EVENT_POOL.take().event; + unsafe { + cpp::cpp!([ + ready_event as "cudaEvent_t", + torch_stream as "cudaStream_t"] + { + CUDACHECK(cudaEventRecord(ready_event, torch_stream)); + CUDACHECK(cudaEventSynchronize(ready_event)); + }); + } + } tracing::debug!("async model average post step end"); } } From 6400066260a47ea66ba5ed5958ddef3deaac8bdf Mon Sep 17 00:00:00 2001 From: ritaw Date: Tue, 7 Sep 2021 12:01:42 +0800 Subject: [PATCH 07/13] avoid potential deadlock --- .../decentralized_full_precision_asynchronous.rs | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/bagua-core-internal/src/comm_ops/decentralized_full_precision_asynchronous.rs b/bagua-core-internal/src/comm_ops/decentralized_full_precision_asynchronous.rs index d0e0cad..0633831 100644 --- a/bagua-core-internal/src/comm_ops/decentralized_full_precision_asynchronous.rs +++ b/bagua-core-internal/src/comm_ops/decentralized_full_precision_asynchronous.rs @@ -121,6 +121,22 @@ impl CommOpTrait for DecentralizedFullPrecisionAsynchronous { } }; + let comm_ready_event = CUDA_EVENT_POOL.take().event; + + unsafe { + cpp::cpp!([ + comm_ready_event as "cudaEvent_t", + comm_stream as "cudaStream_t"] + { + CUDACHECK(cudaEventRecord(comm_ready_event, comm_stream)); + CUDACHECK(cudaEventSynchronize(comm_ready_event)); + }); + } + + if c.check_abort() { + return; + } + { let mut tensor_guard = self.diff_tensor.inner.write(); tensor_guard.raw.async_model_average( From cec1340e2ece9189d3ebd36d8ca439f641b34eeb Mon Sep 17 
00:00:00 2001 From: ritaw Date: Tue, 7 Sep 2021 12:09:28 +0800 Subject: [PATCH 08/13] remove unused --- .../decentralized_full_precision_asynchronous.rs | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/bagua-core-internal/src/comm_ops/decentralized_full_precision_asynchronous.rs b/bagua-core-internal/src/comm_ops/decentralized_full_precision_asynchronous.rs index 0633831..20da046 100644 --- a/bagua-core-internal/src/comm_ops/decentralized_full_precision_asynchronous.rs +++ b/bagua-core-internal/src/comm_ops/decentralized_full_precision_asynchronous.rs @@ -78,18 +78,6 @@ impl CommOpTrait for DecentralizedFullPrecisionAsynchronous { }; let src_ready_event = CUDA_EVENT_POOL.take().event; - let dst_ready_event = CUDA_EVENT_POOL.take().event; - - unsafe { - cpp::cpp!([ - dst_ready_event as "cudaEvent_t", - comm_stream as "cudaStream_t", - torch_stream as "cudaStream_t"] - { - CUDACHECK(cudaEventRecord(dst_ready_event, comm_stream)); - CUDACHECK(cudaStreamWaitEvent(torch_stream, dst_ready_event , 0)); - }); - } // use default stream to copy weights temp_tensor.clone_from(&t.raw, torch_stream as u64); From 612764caefd5557fcd73ea3de6e862b404c88e36 Mon Sep 17 00:00:00 2001 From: ritaw Date: Thu, 9 Sep 2021 21:28:50 +0800 Subject: [PATCH 09/13] keep weights fresh new --- .../comm_ops/decentralized_full_precision_asynchronous.rs | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/bagua-core-internal/src/comm_ops/decentralized_full_precision_asynchronous.rs b/bagua-core-internal/src/comm_ops/decentralized_full_precision_asynchronous.rs index 20da046..bc7085b 100644 --- a/bagua-core-internal/src/comm_ops/decentralized_full_precision_asynchronous.rs +++ b/bagua-core-internal/src/comm_ops/decentralized_full_precision_asynchronous.rs @@ -97,6 +97,11 @@ impl CommOpTrait for DecentralizedFullPrecisionAsynchronous { return; } + { + let tensor_guard = self.diff_tensor.inner.read(); + temp_tensor.add_inplace(tensor_guard.raw.as_ref(), comm_stream as u64); + } + match peer_mode { PeerSelectionMode::All => { c.allreduce(&temp_tensor, &mut reduced_tensor, BaguaReductionOp::SUM); From 82c53dc4e69071c913a6afdd3a7bc5dbf1e4a3b8 Mon Sep 17 00:00:00 2001 From: ritaw Date: Mon, 13 Sep 2021 16:15:24 +0800 Subject: [PATCH 10/13] interleave model average and model update --- ...centralized_full_precision_asynchronous.rs | 50 +++++++++---------- bagua-core-internal/src/datatypes/mod.rs | 2 + 2 files changed, 25 insertions(+), 27 deletions(-) diff --git a/bagua-core-internal/src/comm_ops/decentralized_full_precision_asynchronous.rs b/bagua-core-internal/src/comm_ops/decentralized_full_precision_asynchronous.rs index bc7085b..bfced41 100644 --- a/bagua-core-internal/src/comm_ops/decentralized_full_precision_asynchronous.rs +++ b/bagua-core-internal/src/comm_ops/decentralized_full_precision_asynchronous.rs @@ -7,6 +7,7 @@ use crate::datatypes::{ use crate::events::BaguaEventChannel; use crate::resource_pool::{CUDA_DEVICE_MEMORY_POOL, CUDA_EVENT_POOL}; use crate::{BaguaCommOpChannels, BaguaCoreError}; +use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; use std::time::Duration; @@ -17,6 +18,7 @@ pub struct DecentralizedFullPrecisionAsynchronous { pub torch_stream: u64, pub weight: BaguaTensor, pub diff_tensor: BaguaTensor, + pub has_updated: Arc, } impl CommOpTrait for DecentralizedFullPrecisionAsynchronous { @@ -25,6 +27,12 @@ impl CommOpTrait for DecentralizedFullPrecisionAsynchronous { bucket: Arc, comm_op_channels: &BaguaCommOpChannels, ) { + // Wait until diff tensor is applied + let 
has_updated = self.has_updated.load(Ordering::Acquire); + if has_updated { + return; + } + let bucket_guard = bucket.inner.lock(); let comm_stream = self.communicator.stream_ptr(); @@ -49,7 +57,7 @@ impl CommOpTrait for DecentralizedFullPrecisionAsynchronous { false, &mut |c, t| { let start_time = std::time::Instant::now(); - tracing::debug!("async model average start"); + tracing::debug!("#{} async model average start", c.rank); let temp_buf = CUDA_DEVICE_MEMORY_POOL[t.raw.device_id()] .try_pull(t.raw.num_elements_allocated() * t.raw.dtype().bytes()) @@ -93,15 +101,6 @@ impl CommOpTrait for DecentralizedFullPrecisionAsynchronous { }); } - if c.check_abort() { - return; - } - - { - let tensor_guard = self.diff_tensor.inner.read(); - temp_tensor.add_inplace(tensor_guard.raw.as_ref(), comm_stream as u64); - } - match peer_mode { PeerSelectionMode::All => { c.allreduce(&temp_tensor, &mut reduced_tensor, BaguaReductionOp::SUM); @@ -114,22 +113,6 @@ impl CommOpTrait for DecentralizedFullPrecisionAsynchronous { } }; - let comm_ready_event = CUDA_EVENT_POOL.take().event; - - unsafe { - cpp::cpp!([ - comm_ready_event as "cudaEvent_t", - comm_stream as "cudaStream_t"] - { - CUDACHECK(cudaEventRecord(comm_ready_event, comm_stream)); - CUDACHECK(cudaEventSynchronize(comm_ready_event)); - }); - } - - if c.check_abort() { - return; - } - { let mut tensor_guard = self.diff_tensor.inner.write(); tensor_guard.raw.async_model_average( @@ -151,8 +134,12 @@ impl CommOpTrait for DecentralizedFullPrecisionAsynchronous { }); } } + + self.has_updated.store(true, Ordering::Release); + tracing::debug!( - "async model average update cost: {:?}", + "#{} async model average update cost: {:?}", + c.rank, start_time.elapsed() ); }, @@ -160,10 +147,17 @@ impl CommOpTrait for DecentralizedFullPrecisionAsynchronous { } fn execute_post_step(&self, bucket: Arc) { + let has_updated = self.has_updated.load(Ordering::Acquire); + if !has_updated { + return; + } + tracing::debug!("async model average post step start"); let torch_stream = self.torch_stream; { let mut tensor_guard = self.diff_tensor.inner.write(); + + // We must patch diff tensor to model weights here let mut guard = self.weight.inner.write(); guard @@ -182,6 +176,8 @@ impl CommOpTrait for DecentralizedFullPrecisionAsynchronous { }); } } + + self.has_updated.store(false, Ordering::Release); tracing::debug!("async model average post step end"); } } diff --git a/bagua-core-internal/src/datatypes/mod.rs b/bagua-core-internal/src/datatypes/mod.rs index 2674a4a..b62b727 100644 --- a/bagua-core-internal/src/datatypes/mod.rs +++ b/bagua-core-internal/src/datatypes/mod.rs @@ -19,6 +19,7 @@ use pyo3::types::IntoPyDict; use sized_object_pool::DynamicPoolItem; use std::ffi::c_void; use std::fmt::Debug; +use std::sync::atomic::AtomicBool; use std::sync::Arc; // must be consistent with Aluminum ReductionOperator: https://github.com/BaguaSys/Aluminum/blob/master/include/aluminum/base.hpp @@ -1283,6 +1284,7 @@ impl BaguaBucket { torch_stream, weight, diff_tensor, + has_updated: Arc::new(AtomicBool::new(false)), }, ); From 8932417cf21db09d7eed95f2c761f92beb375c99 Mon Sep 17 00:00:00 2001 From: ritaw Date: Tue, 14 Sep 2021 16:40:40 +0800 Subject: [PATCH 11/13] use mutex instead --- ...centralized_full_precision_asynchronous.rs | 62 +++++-------------- bagua-core-internal/src/comm_ops/mod.rs | 2 - bagua-core-internal/src/datatypes/mod.rs | 53 +++++----------- bagua-core-py/src/lib.rs | 31 +++++----- 4 files changed, 47 insertions(+), 101 deletions(-) diff --git 
a/bagua-core-internal/src/comm_ops/decentralized_full_precision_asynchronous.rs b/bagua-core-internal/src/comm_ops/decentralized_full_precision_asynchronous.rs index bfced41..5b72419 100644 --- a/bagua-core-internal/src/comm_ops/decentralized_full_precision_asynchronous.rs +++ b/bagua-core-internal/src/comm_ops/decentralized_full_precision_asynchronous.rs @@ -7,7 +7,7 @@ use crate::datatypes::{ use crate::events::BaguaEventChannel; use crate::resource_pool::{CUDA_DEVICE_MEMORY_POOL, CUDA_EVENT_POOL}; use crate::{BaguaCommOpChannels, BaguaCoreError}; -use std::sync::atomic::{AtomicBool, Ordering}; +use parking_lot::{lock_api::RawMutex as _, Mutex, RawMutex}; use std::sync::Arc; use std::time::Duration; @@ -16,9 +16,7 @@ pub struct DecentralizedFullPrecisionAsynchronous { pub communicator: BaguaCommunicator, pub peer_selection_mode: PeerSelectionMode, pub torch_stream: u64, - pub weight: BaguaTensor, - pub diff_tensor: BaguaTensor, - pub has_updated: Arc, + pub weight_mutex: Arc>, } impl CommOpTrait for DecentralizedFullPrecisionAsynchronous { @@ -27,12 +25,6 @@ impl CommOpTrait for DecentralizedFullPrecisionAsynchronous { bucket: Arc, comm_op_channels: &BaguaCommOpChannels, ) { - // Wait until diff tensor is applied - let has_updated = self.has_updated.load(Ordering::Acquire); - if has_updated { - return; - } - let bucket_guard = bucket.inner.lock(); let comm_stream = self.communicator.stream_ptr(); @@ -114,8 +106,8 @@ impl CommOpTrait for DecentralizedFullPrecisionAsynchronous { }; { - let mut tensor_guard = self.diff_tensor.inner.write(); - tensor_guard.raw.async_model_average( + self.lock_weight(); + t.raw.async_model_average( &reduced_tensor, &temp_tensor, c.nranks as f32, @@ -133,10 +125,9 @@ impl CommOpTrait for DecentralizedFullPrecisionAsynchronous { CUDACHECK(cudaEventSynchronize(ready_event)); }); } + self.unlock_weight(); } - self.has_updated.store(true, Ordering::Release); - tracing::debug!( "#{} async model average update cost: {:?}", c.rank, @@ -145,39 +136,18 @@ impl CommOpTrait for DecentralizedFullPrecisionAsynchronous { }, ); } +} - fn execute_post_step(&self, bucket: Arc) { - let has_updated = self.has_updated.load(Ordering::Acquire); - if !has_updated { - return; - } - - tracing::debug!("async model average post step start"); - let torch_stream = self.torch_stream; - { - let mut tensor_guard = self.diff_tensor.inner.write(); - - // We must patch diff tensor to model weights here - let mut guard = self.weight.inner.write(); - - guard - .raw - .async_model_update(tensor_guard.raw.as_ref(), torch_stream); - - let ready_event = CUDA_EVENT_POOL.take().event; - - unsafe { - cpp::cpp!([ - ready_event as "cudaEvent_t", - torch_stream as "cudaStream_t"] - { - CUDACHECK(cudaEventRecord(ready_event, torch_stream)); - CUDACHECK(cudaEventSynchronize(ready_event)); - }); - } - } +impl DecentralizedFullPrecisionAsynchronous { + pub fn lock_weight(&self) { + let raw_mutex = unsafe { self.weight_mutex.raw() }; + raw_mutex.lock(); + } - self.has_updated.store(false, Ordering::Release); - tracing::debug!("async model average post step end"); + pub fn unlock_weight(&self) { + unsafe { + let raw_mutex = self.weight_mutex.raw(); + raw_mutex.unlock(); + }; } } diff --git a/bagua-core-internal/src/comm_ops/mod.rs b/bagua-core-internal/src/comm_ops/mod.rs index 9598f22..f22aa73 100644 --- a/bagua-core-internal/src/comm_ops/mod.rs +++ b/bagua-core-internal/src/comm_ops/mod.rs @@ -16,6 +16,4 @@ pub trait CommOpTrait: Debug { bucket: Arc, comm_channels: &BaguaCommOpChannels, ); - - fn 
execute_post_step(&self, bucket: Arc) {} } diff --git a/bagua-core-internal/src/datatypes/mod.rs b/bagua-core-internal/src/datatypes/mod.rs index b62b727..a4c182a 100644 --- a/bagua-core-internal/src/datatypes/mod.rs +++ b/bagua-core-internal/src/datatypes/mod.rs @@ -1087,21 +1087,6 @@ impl<'b> Drop for BaguaCommunicationTensor<'b> { } } -#[derive(Debug, Clone)] -pub struct BaguaCommOp { - pub name: String, - pub inner: Arc, -} - -impl BaguaCommOp { - pub fn execute_post_step(&self, bucket: &BaguaBucket) -> Result<(), BaguaCoreError> { - let bucket = Arc::new((*bucket).clone()); - self.inner.execute_post_step(bucket); - - Ok(()) - } -} - #[derive(Debug, Clone)] pub struct BaguaBucket { pub name: String, @@ -1265,35 +1250,29 @@ impl BaguaBucket { communicator_intranode: Option<&BaguaSingleCommunicator>, peer_selection_mode: String, torch_stream: u64, - weight: BaguaTensor, - diff_tensor: BaguaTensor, - ) -> BaguaCommOp { + ) -> Arc { let communicator = BaguaCommunicator::new(communicator_internode, communicator_intranode, false) .expect("cannot create communicator"); - let comm_op: Arc = Arc::new( - DecentralizedFullPrecisionAsynchronous { - communicator, - peer_selection_mode: match peer_selection_mode.as_str() { - "all" => PeerSelectionMode::All, - &_ => { - unimplemented!("unsupported peer_selection_mode for decentralized asynchronous algorithm (should be `all`)") - } - }, - torch_stream, - weight, - diff_tensor, - has_updated: Arc::new(AtomicBool::new(false)), + let comm_op = Arc::new(DecentralizedFullPrecisionAsynchronous { + communicator, + peer_selection_mode: match peer_selection_mode.as_str() { + "all" => PeerSelectionMode::All, + &_ => { + unimplemented!("unsupported peer_selection_mode for decentralized asynchronous algorithm (should be `all`)") + } }, - ); + torch_stream, + weight_mutex: Arc::new(Mutex::new(true)), + }); - self.inner.lock().comm_ops.push(comm_op.clone()); + self.inner + .lock() + .comm_ops + .push(comm_op.clone() as Arc); - BaguaCommOp { - name: String::from("decentralized_async_op"), - inner: comm_op, - } + comm_op } pub fn ready_for_comm(&self) -> bool { diff --git a/bagua-core-py/src/lib.rs b/bagua-core-py/src/lib.rs index 098c694..ffbf343 100644 --- a/bagua-core-py/src/lib.rs +++ b/bagua-core-py/src/lib.rs @@ -1,8 +1,9 @@ #![allow(clippy::needless_return)] +use bagua_core_internal::comm_ops::decentralized_full_precision_asynchronous::DecentralizedFullPrecisionAsynchronous; use bagua_core_internal::communicators::BaguaSingleCommunicator; use bagua_core_internal::datatypes::{ - BaguaBucket, BaguaCommOp, BaguaReductionOp, BaguaTensor, BaguaTensorDtype, + BaguaBucket, BaguaReductionOp, BaguaTensor, BaguaTensorDtype, }; use bagua_core_internal::BaguaCommBackend; use num_traits::FromPrimitive; @@ -10,6 +11,7 @@ use numpy::{IntoPyArray, PyArray1}; use pyo3::exceptions::PyRuntimeError; use pyo3::prelude::*; use pyo3::PyNativeType; +use std::sync::Arc; #[pyclass(dict)] pub struct BaguaSingleCommunicatorPy { @@ -194,17 +196,18 @@ impl BaguaSingleCommunicatorPy { } #[pyclass(dict)] -pub struct BaguaCommOpPy { - inner: BaguaCommOp, +pub struct DecentralizedFullPrecisionAsynchronousPy { + inner: Arc, } #[pymethods] -impl BaguaCommOpPy { - pub fn execute_post_step(&self, bucket: PyRef) -> PyResult<()> { - let bucket_inner = &bucket.inner; - self.inner - .execute_post_step(bucket_inner) - .map_err(|e| PyRuntimeError::new_err(format!("{:?}", e))) +impl DecentralizedFullPrecisionAsynchronousPy { + pub fn lock_weight(&self) { + self.inner.lock_weight() + } + + pub fn 
unlock_weight(&self) { + self.inner.unlock_weight() } } @@ -471,17 +474,13 @@ impl BaguaBucketPy { communicator_intranode: Option<&BaguaSingleCommunicatorPy>, peer_selection_mode: String, torch_stream: u64, - weight: PyRef, - diff_tensor: PyRef, - ) -> BaguaCommOpPy { - BaguaCommOpPy { + ) -> DecentralizedFullPrecisionAsynchronousPy { + DecentralizedFullPrecisionAsynchronousPy { inner: self.inner.append_decentralized_asynchronous_op( communicator_internode.map(|x| &x.inner), communicator_intranode.map(|x| &x.inner), peer_selection_mode, torch_stream, - (*weight).inner.clone(), - (*diff_tensor).inner.clone(), ), } } @@ -527,7 +526,7 @@ fn bagua_core(_py: Python, m: &PyModule) -> PyResult<()> { m.add_class::()?; m.add_class::()?; m.add_class::()?; - m.add_class::()?; + m.add_class::()?; #[pyfn(m, "show_version")] fn show_version(_py: Python) { From 40899436410eecf1b3fa6f535a2c6ee29c587b00 Mon Sep 17 00:00:00 2001 From: ritaw Date: Tue, 14 Sep 2021 17:01:46 +0800 Subject: [PATCH 12/13] add --- .../decentralized_full_precision_asynchronous.rs | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/bagua-core-internal/src/comm_ops/decentralized_full_precision_asynchronous.rs b/bagua-core-internal/src/comm_ops/decentralized_full_precision_asynchronous.rs index 5b72419..8049bdc 100644 --- a/bagua-core-internal/src/comm_ops/decentralized_full_precision_asynchronous.rs +++ b/bagua-core-internal/src/comm_ops/decentralized_full_precision_asynchronous.rs @@ -106,6 +106,17 @@ impl CommOpTrait for DecentralizedFullPrecisionAsynchronous { }; { + let ready_event = CUDA_EVENT_POOL.take().event; + unsafe { + cpp::cpp!([ + ready_event as "cudaEvent_t", + comm_stream as "cudaStream_t"] + { + CUDACHECK(cudaEventRecord(ready_event, comm_stream)); + CUDACHECK(cudaEventSynchronize(ready_event)); + }); + } + self.lock_weight(); t.raw.async_model_average( &reduced_tensor, @@ -114,8 +125,6 @@ impl CommOpTrait for DecentralizedFullPrecisionAsynchronous { comm_stream, ); - let ready_event = CUDA_EVENT_POOL.take().event; - unsafe { cpp::cpp!([ ready_event as "cudaEvent_t", From 970d6c17d083770c9e4bd80c66b0641c703d361e Mon Sep 17 00:00:00 2001 From: rita Date: Tue, 14 Sep 2021 17:14:55 +0800 Subject: [PATCH 13/13] remove unused --- bagua-core-internal/kernels/bagua_kernels.cu | 12 ------------ .../decentralized_full_precision_asynchronous.rs | 2 +- bagua-core-internal/src/datatypes/mod.rs | 16 ---------------- bagua-core-internal/src/kernels/mod.rs | 6 ------ 4 files changed, 1 insertion(+), 35 deletions(-) diff --git a/bagua-core-internal/kernels/bagua_kernels.cu b/bagua-core-internal/kernels/bagua_kernels.cu index 513b857..195246e 100644 --- a/bagua-core-internal/kernels/bagua_kernels.cu +++ b/bagua-core-internal/kernels/bagua_kernels.cu @@ -266,13 +266,6 @@ __global__ void async_model_average(float *tensor, const float *reduced_tensor_c } } -__global__ void async_model_update(float *tensor, float *diff_tensor, const int N) { - for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N; i += blockDim.x * gridDim.x) { - tensor[i] += diff_tensor[i]; - diff_tensor[i] = 0.0; - } -} - template size_t array_min_max_size( const T *input_array, @@ -635,11 +628,6 @@ void async_model_average_host(float *tensor, const float *reduced_tensor_copy, CUDACHECK(cudaGetLastError()); } -void async_model_update_host(float *tensor, float *diff_tensor, const int N, cudaStream_t stream) { - async_model_update<<>>(tensor, diff_tensor, N); - CUDACHECK(cudaGetLastError()); -} - //// decentralize, recvbuf should get the 
average of sendbuf and peer's sendbuf //ncclResult_t ncclPeerAverage(void *sendbuf, void *recvbuf, size_t sendcount, // int peer_rank, ncclDataType_t datatype, ncclComm_t comm, cudaStream_t stream) { diff --git a/bagua-core-internal/src/comm_ops/decentralized_full_precision_asynchronous.rs b/bagua-core-internal/src/comm_ops/decentralized_full_precision_asynchronous.rs index 8049bdc..da6cc5c 100644 --- a/bagua-core-internal/src/comm_ops/decentralized_full_precision_asynchronous.rs +++ b/bagua-core-internal/src/comm_ops/decentralized_full_precision_asynchronous.rs @@ -116,7 +116,7 @@ impl CommOpTrait for DecentralizedFullPrecisionAsynchronous { CUDACHECK(cudaEventSynchronize(ready_event)); }); } - + self.lock_weight(); t.raw.async_model_average( &reduced_tensor, diff --git a/bagua-core-internal/src/datatypes/mod.rs b/bagua-core-internal/src/datatypes/mod.rs index a4c182a..8319d61 100644 --- a/bagua-core-internal/src/datatypes/mod.rs +++ b/bagua-core-internal/src/datatypes/mod.rs @@ -166,22 +166,6 @@ pub trait RawBaguaTensor: Debug { } } - fn async_model_update(&mut self, diff_tensor: &dyn RawBaguaTensor, stream_ptr: u64) { - assert_eq!(self.dtype(), diff_tensor.dtype()); - assert_eq!(self.num_elements(), diff_tensor.num_elements()); - - let tensor_ptr = self.data_ptr(); - let total_num_elem = self.num_elements(); - unsafe { - kernels::async_model_update_host( - tensor_ptr as _, - diff_tensor.data_ptr() as _, - total_num_elem as i32, - stream_ptr as _, - ); - } - } - fn substract_inplace(&mut self, other: &dyn RawBaguaTensor, stream_ptr: u64) { assert_eq!(self.dtype(), other.dtype()); assert_eq!(self.num_elements(), other.num_elements()); diff --git a/bagua-core-internal/src/kernels/mod.rs b/bagua-core-internal/src/kernels/mod.rs index 3fc6ad3..55df19e 100644 --- a/bagua-core-internal/src/kernels/mod.rs +++ b/bagua-core-internal/src/kernels/mod.rs @@ -134,10 +134,4 @@ extern "C" { N: i32, stream: *const c_void, ); - pub fn async_model_update_host( - tensor: *mut c_void, - diff_tensor: *mut c_void, - N: i32, - stream: *const c_void, - ); }
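
The series ends with `append_decentralized_asynchronous_op` returning a `DecentralizedFullPrecisionAsynchronousPy` handle whose `lock_weight()` / `unlock_weight()` serialize access to the model weights between the background averaging thread and the training loop (the Rust side takes the same mutex around `async_model_average`). Below is a minimal Python sketch of the intended call pattern, assuming an already-constructed `bucket` (`BaguaBucketPy`) registered with the backend and the usual `model`, `optimizer`, and `loader` objects; only `append_decentralized_asynchronous_op`, `lock_weight`, and `unlock_weight` come from these patches, everything else is illustrative.

    import torch

    # Register the asynchronous model-average op on the bucket and keep the
    # returned handle so the training loop can guard weight updates.
    op = bucket.append_decentralized_asynchronous_op(
        None,                                        # communicator_internode
        None,                                        # communicator_intranode
        "all",                                       # peer_selection_mode
        torch.cuda.current_stream().cuda_stream,     # torch_stream (raw cudaStream_t)
    )

    for batch in loader:
        loss = model(batch).sum()
        loss.backward()

        op.lock_weight()        # block the comm thread's async_model_average
        optimizer.step()        # weight update runs on the torch stream
        optimizer.zero_grad()
        op.unlock_weight()      # averaging may resume

Holding the lock only around code that reads or writes the flattened weights is what makes the raw-mutex design (lock in one thread, unlock possibly after asynchronous work) preferable here to a scoped guard.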