11 changes: 6 additions & 5 deletions bagua-core-internal/kernels/bagua_kernels.cu
@@ -257,11 +257,12 @@ __global__ void divide_inplace_f16(__half *x, float D_, int N) {
__global__ void async_model_average(float *tensor, const float *reduced_tensor_copy,
const float *tensor_copy, const float nranks, const int N) {
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N; i += blockDim.x * gridDim.x) {
// tensor[i] += reduced_tensor_copy[i] / nranks - tensor_copy[i];
// if (tensor[i] != tensor[i]) {
// printf("nan encountered!");
// }
atomicAdd(&tensor[i], reduced_tensor_copy[i] / nranks - tensor_copy[i]);

tensor[i] += reduced_tensor_copy[i] / nranks - tensor_copy[i];
/*if (tensor[i] != tensor[i]) {
printf("nan encountered!");
}*/
// atomicAdd(&tensor[i], reduced_tensor_copy[i] / nranks - tensor_copy[i]);
}
}

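Note on the kernel change above: the update switches from `atomicAdd` to a plain read-modify-write. Each index `i` in the grid-stride loop is visited by exactly one thread, so the atomic only guarded against other work mutating `tensor` concurrently; that is exactly the race the weight lock introduced later in this PR is meant to exclude. A host-side sketch of the per-element rule, with hypothetical names for illustration only:

// Sketch of what async_model_average computes per element (hypothetical
// helper, not part of this PR): the tensor moves toward the average of
// all ranks' snapshots taken at communication time.
fn model_average_step(tensor: &mut [f32], reduced_copy: &[f32], local_copy: &[f32], nranks: f32) {
    for i in 0..tensor.len() {
        // reduced_copy[i] holds the SUM of all ranks' snapshots, so
        // dividing by nranks gives the average; subtracting the local
        // snapshot preserves updates made to `tensor` since it was taken.
        tensor[i] += reduced_copy[i] / nranks - local_copy[i];
    }
}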
bagua-core-internal/src/comm_ops/decentralized_full_precision_asynchronous.rs
@@ -1,10 +1,13 @@
use crate::comm_ops::decentralized_full_precision_synchronous::PeerSelectionMode;
use crate::comm_ops::CommOpTrait;
use crate::communicators::BaguaCommunicator;
use crate::datatypes::{BaguaBucket, BaguaReductionOp, BaguaTensorRaw, RawBaguaTensor};
use crate::datatypes::{
BaguaBucket, BaguaReductionOp, BaguaTensor, BaguaTensorRaw, RawBaguaTensor,
};
use crate::events::BaguaEventChannel;
use crate::resource_pool::{CUDA_DEVICE_MEMORY_POOL, CUDA_EVENT_POOL};
use crate::{BaguaCommOpChannels, BaguaCoreError};
use parking_lot::{lock_api::RawMutex as _, Mutex, RawMutex};
use std::sync::Arc;
use std::time::Duration;

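The new `parking_lot` import matters because `lock_weight` and `unlock_weight` (added at the bottom of this file) acquire and release the mutex from separate methods, which the scoped guard of `std::sync::Mutex` cannot express. A minimal sketch of the raw lock/unlock pattern, assuming parking_lot's `lock_api` re-export:

use parking_lot::{lock_api::RawMutex as _, Mutex};
use std::sync::Arc;

// Sketch only: Mutex::raw() exposes the underlying RawMutex so a lock
// taken in one method can be released in another. The caller assumes
// the obligation that every lock() is paired with exactly one unlock().
fn lock_unlock_demo(weight_mutex: &Arc<Mutex<bool>>) {
    let raw = unsafe { weight_mutex.raw() };
    raw.lock(); // blocks until the weights are free
    // ... read or update the weights here ...
    unsafe { raw.unlock() };
}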
@@ -13,6 +16,7 @@ pub struct DecentralizedFullPrecisionAsynchronous {
pub communicator: BaguaCommunicator,
pub peer_selection_mode: PeerSelectionMode,
pub torch_stream: u64,
pub weight_mutex: Arc<Mutex<bool>>,
}

impl CommOpTrait for DecentralizedFullPrecisionAsynchronous {
@@ -22,6 +26,7 @@ impl CommOpTrait for DecentralizedFullPrecisionAsynchronous {
comm_op_channels: &BaguaCommOpChannels,
) {
let bucket_guard = bucket.inner.lock();

let comm_stream = self.communicator.stream_ptr();

let mut communication_tensor = match &self.communicator {
@@ -44,7 +49,7 @@ impl CommOpTrait for DecentralizedFullPrecisionAsynchronous {
false,
&mut |c, t| {
let start_time = std::time::Instant::now();
tracing::debug!("async model average start");
tracing::debug!("#{} async model average start", c.rank);

let temp_buf = CUDA_DEVICE_MEMORY_POOL[t.raw.device_id()]
.try_pull(t.raw.num_elements_allocated() * t.raw.dtype().bytes())
@@ -72,11 +77,11 @@ impl CommOpTrait for DecentralizedFullPrecisionAsynchronous {
pool_allocations: vec![Arc::new(reduced_buf)],
};

let src_ready_event = CUDA_EVENT_POOL.take().event;

// use default stream to copy weights
temp_tensor.clone_from(&t.raw, torch_stream as u64);

let src_ready_event = CUDA_EVENT_POOL.take().event;

unsafe {
cpp::cpp!([
src_ready_event as "cudaEvent_t",
@@ -88,10 +93,6 @@ impl CommOpTrait for DecentralizedFullPrecisionAsynchronous {
});
}

if c.check_abort() {
return;
}

match peer_mode {
PeerSelectionMode::All => {
c.allreduce(&temp_tensor, &mut reduced_tensor, BaguaReductionOp::SUM);
@@ -104,53 +105,58 @@ impl CommOpTrait for DecentralizedFullPrecisionAsynchronous {
}
};

let comm_ready_event = CUDA_EVENT_POOL.take().event;

unsafe {
cpp::cpp!([
comm_ready_event as "cudaEvent_t",
comm_stream as "cudaStream_t"]
{
CUDACHECK(cudaEventRecord(comm_ready_event, comm_stream));
CUDACHECK(cudaEventSynchronize(comm_ready_event));
});
}

if c.check_abort() {
return;
}

// do we need to wait default stream?
unsafe {
cpp::cpp!([
src_ready_event as "cudaEvent_t",
comm_stream as "cudaStream_t",
torch_stream as "cudaStream_t"]
{
CUDACHECK(cudaEventRecord(src_ready_event, torch_stream));
CUDACHECK(cudaStreamWaitEvent(comm_stream, src_ready_event , 0));
});
}

t.raw.async_model_average(
&reduced_tensor,
&temp_tensor,
c.nranks as f32,
comm_stream,
);
{
let ready_event = CUDA_EVENT_POOL.take().event;
unsafe {
cpp::cpp!([
ready_event as "cudaEvent_t",
comm_stream as "cudaStream_t"]
{
CUDACHECK(cudaEventRecord(ready_event, comm_stream));
CUDACHECK(cudaEventSynchronize(ready_event));
});
}

unsafe {
cpp::cpp!([comm_stream as "cudaStream_t"]
{
CUDACHECK(cudaStreamSynchronize(comm_stream));
});
self.lock_weight();
t.raw.async_model_average(
&reduced_tensor,
&temp_tensor,
c.nranks as f32,
comm_stream,
);

unsafe {
cpp::cpp!([
ready_event as "cudaEvent_t",
comm_stream as "cudaStream_t"]
{
CUDACHECK(cudaEventRecord(ready_event, comm_stream));
CUDACHECK(cudaEventSynchronize(ready_event));
});
}
self.unlock_weight();
}

tracing::debug!(
"async model average update cost: {:?}",
"#{} async model average update cost: {:?}",
c.rank,
start_time.elapsed()
);
},
);
}
}

impl DecentralizedFullPrecisionAsynchronous {
pub fn lock_weight(&self) {
let raw_mutex = unsafe { self.weight_mutex.raw() };
raw_mutex.lock();
}

pub fn unlock_weight(&self) {
unsafe {
let raw_mutex = self.weight_mutex.raw();
raw_mutex.unlock();
};
}
}
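With these pieces in place, the communication thread synchronizes on the allreduce via a CUDA event, then holds `weight_mutex` only while the averaging kernel runs and its completion event is awaited. The counterpart, not shown in this diff, is that whoever mutates the weights on the default stream takes the same lock; a hedged sketch, assuming an `op` of the type above:

fn guarded_optimizer_step(op: &DecentralizedFullPrecisionAsynchronous, step: impl FnOnce()) {
    op.lock_weight(); // background model averaging cannot run concurrently
    step();           // mutate the weights, e.g. an optimizer update
    op.unlock_weight();
}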
32 changes: 19 additions & 13 deletions bagua-core-internal/src/datatypes/mod.rs
@@ -19,6 +19,7 @@ use pyo3::types::IntoPyDict;
use sized_object_pool::DynamicPoolItem;
use std::ffi::c_void;
use std::fmt::Debug;
use std::sync::atomic::AtomicBool;
use std::sync::Arc;

// must be consistent with Aluminum ReductionOperator: https://github.com/BaguaSys/Aluminum/blob/master/include/aluminum/base.hpp
@@ -164,6 +165,7 @@ pub trait RawBaguaTensor: Debug {
);
}
}

fn substract_inplace(&mut self, other: &dyn RawBaguaTensor, stream_ptr: u64) {
assert_eq!(self.dtype(), other.dtype());
assert_eq!(self.num_elements(), other.num_elements());
@@ -1232,25 +1234,29 @@ impl BaguaBucket {
communicator_intranode: Option<&BaguaSingleCommunicator>,
peer_selection_mode: String,
torch_stream: u64,
) {
) -> Arc<DecentralizedFullPrecisionAsynchronous> {
let communicator =
BaguaCommunicator::new(communicator_internode, communicator_intranode, false)
.expect("cannot create communicator");

let comm_op: Arc<dyn CommOpTrait + Send + Sync> = Arc::new(
DecentralizedFullPrecisionAsynchronous {
communicator,
peer_selection_mode: match peer_selection_mode.as_str() {
"all" => PeerSelectionMode::All,
&_ => {
unimplemented!("unsupported peer_selection_mode for decentralized asynchronous algorithm (should be `all`)")
}
},
torch_stream,
let comm_op = Arc::new(DecentralizedFullPrecisionAsynchronous {
communicator,
peer_selection_mode: match peer_selection_mode.as_str() {
"all" => PeerSelectionMode::All,
&_ => {
unimplemented!("unsupported peer_selection_mode for decentralized asynchronous algorithm (should be `all`)")
}
},
);
torch_stream,
weight_mutex: Arc::new(Mutex::new(true)),
});

self.inner.lock().comm_ops.push(comm_op);
self.inner
.lock()
.comm_ops
.push(comm_op.clone() as Arc<dyn CommOpTrait + Send + Sync>);

comm_op
}
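The signature change is the point of this hunk: the bucket keeps storing the op as a trait object, but the concrete `Arc` is now returned so the Python layer can reach `lock_weight`/`unlock_weight`. The same pattern in isolation, with illustrative stand-in types rather than the real bagua ones:

use std::sync::Arc;

trait CommOp { fn execute(&self); }
struct AsyncOp;
impl CommOp for AsyncOp { fn execute(&self) {} }

fn append_op(ops: &mut Vec<Arc<dyn CommOp + Send + Sync>>) -> Arc<AsyncOp> {
    let op = Arc::new(AsyncOp);
    // one clone goes into the op list as a trait object...
    ops.push(op.clone() as Arc<dyn CommOp + Send + Sync>);
    // ...while the concrete handle keeps its inherent methods available
    op
}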

pub fn ready_for_comm(&self) -> bool {
36 changes: 28 additions & 8 deletions bagua-core-py/src/lib.rs
@@ -1,5 +1,6 @@
#![allow(clippy::needless_return)]

use bagua_core_internal::comm_ops::decentralized_full_precision_asynchronous::DecentralizedFullPrecisionAsynchronous;
use bagua_core_internal::communicators::BaguaSingleCommunicator;
use bagua_core_internal::datatypes::{
BaguaBucket, BaguaReductionOp, BaguaTensor, BaguaTensorDtype,
@@ -10,6 +11,7 @@ use numpy::{IntoPyArray, PyArray1};
use pyo3::exceptions::PyRuntimeError;
use pyo3::prelude::*;
use pyo3::PyNativeType;
use std::sync::Arc;

#[pyclass(dict)]
pub struct BaguaSingleCommunicatorPy {
@@ -193,6 +195,22 @@ impl BaguaSingleCommunicatorPy {
}
}

#[pyclass(dict)]
pub struct DecentralizedFullPrecisionAsynchronousPy {
inner: Arc<DecentralizedFullPrecisionAsynchronous>,
}

#[pymethods]
impl DecentralizedFullPrecisionAsynchronousPy {
pub fn lock_weight(&self) {
self.inner.lock_weight()
}

pub fn unlock_weight(&self) {
self.inner.unlock_weight()
}
}
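This wrapper is a thin handle whose only job is to let a Python training loop bracket its weight updates with the lock while averaging proceeds on the background thread. The intended round trip, sketched from the Rust side (the Python spelling would be `op.lock_weight()` / `op.unlock_weight()` on the object returned by `append_decentralized_asynchronous_op`):

fn python_facing_protocol(op: &DecentralizedFullPrecisionAsynchronousPy) {
    op.lock_weight();   // called before the optimizer step in Python
    // ... weights are mutated on the default stream ...
    op.unlock_weight(); // called afterwards, letting averaging resume
}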

#[pyclass(dict)]
pub struct BaguaTensorPy {
inner: BaguaTensor,
@@ -456,14 +474,15 @@ impl BaguaBucketPy {
communicator_intranode: Option<&BaguaSingleCommunicatorPy>,
peer_selection_mode: String,
torch_stream: u64,
) -> PyResult<()> {
self.inner.append_decentralized_asynchronous_op(
communicator_internode.map(|x| &x.inner),
communicator_intranode.map(|x| &x.inner),
peer_selection_mode,
torch_stream,
);
Ok(())
) -> DecentralizedFullPrecisionAsynchronousPy {
DecentralizedFullPrecisionAsynchronousPy {
inner: self.inner.append_decentralized_asynchronous_op(
communicator_internode.map(|x| &x.inner),
communicator_intranode.map(|x| &x.inner),
peer_selection_mode,
torch_stream,
),
}
}

pub fn print_ops(&self) -> PyResult<()> {
@@ -507,6 +526,7 @@ fn bagua_core(_py: Python, m: &PyModule) -> PyResult<()> {
m.add_class::<BaguaTensorPy>()?;
m.add_class::<BaguaBucketPy>()?;
m.add_class::<BaguaSingleCommunicatorPy>()?;
m.add_class::<DecentralizedFullPrecisionAsynchronousPy>()?;

#[pyfn(m, "show_version")]
fn show_version(_py: Python) {