From 7e8dda19737921b0b828116451192b2b2548caf5 Mon Sep 17 00:00:00 2001 From: Dennis van der Staay Date: Tue, 29 Jul 2025 22:45:30 -0700 Subject: [PATCH] RDMAXCEL - extended, accelerated - WQE. CQE and DoorBells (#541) Summary: This is basic cuda logic to support writing/reading basic data structures used for ibverbs comms. Effectively this same as ibv_send/ibv_recv but written so that is can executed on cpu or cuda. Reviewed By: allenwang28 Differential Revision: D78362879 --- Cargo.toml | 4 +- cuda-sys/build.rs | 50 ++- monarch_rdma/Cargo.toml | 2 +- monarch_rdma/src/ibverbs_primitives.rs | 96 ++-- monarch_rdma/src/rdma_components.rs | 186 ++++---- monarch_rdma/src/rdma_manager_actor.rs | 2 +- rdmacore-sys/build.rs | 69 --- {rdmacore-sys => rdmaxcel-sys}/Cargo.toml | 8 +- rdmaxcel-sys/build.rs | 244 ++++++++++ {rdmacore-sys => rdmaxcel-sys}/src/lib.rs | 0 rdmaxcel-sys/src/rdmaxcel.c | 207 +++++++++ rdmaxcel-sys/src/rdmaxcel.cu | 420 ++++++++++++++++++ rdmaxcel-sys/src/rdmaxcel.h | 104 +++++ .../src/test_rdmaxcel.c | 10 +- 14 files changed, 1169 insertions(+), 233 deletions(-) delete mode 100644 rdmacore-sys/build.rs rename {rdmacore-sys => rdmaxcel-sys}/Cargo.toml (68%) create mode 100644 rdmaxcel-sys/build.rs rename {rdmacore-sys => rdmaxcel-sys}/src/lib.rs (100%) create mode 100644 rdmaxcel-sys/src/rdmaxcel.c create mode 100644 rdmaxcel-sys/src/rdmaxcel.cu create mode 100644 rdmaxcel-sys/src/rdmaxcel.h rename rdmacore-sys/src/wrapper.h => rdmaxcel-sys/src/test_rdmaxcel.c (57%) diff --git a/Cargo.toml b/Cargo.toml index f2f721ee7..902b5308b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,8 +14,6 @@ members = [ "monarch_tensor_worker", "monarch_rdma", "nccl-sys", - "rdmacore-sys", + "rdmaxcel-sys", "torch-sys", - "rdmacore-sys", - "cuda-sys", ] diff --git a/cuda-sys/build.rs b/cuda-sys/build.rs index 5197972ba..f9da866fc 100644 --- a/cuda-sys/build.rs +++ b/cuda-sys/build.rs @@ -74,7 +74,7 @@ fn emit_cuda_link_directives(cuda_home: &str) { } fn python_env_dirs() 
-> (Option, Option) { - let output = std::process::Command::new(PathBuf::from("python")) + let output = std::process::Command::new(PathBuf::from("python3")) .arg("-c") .arg(PYTHON_PRINT_DIRS) .output() @@ -94,13 +94,13 @@ fn python_env_dirs() -> (Option, Option) { } fn main() { + // Start building the bindgen configuration let mut builder = bindgen::Builder::default() // The input header we would like to generate bindings for .header("src/wrapper.h") .clang_arg("-x") .clang_arg("c++") .clang_arg("-std=gnu++20") - .clang_arg(format!("-I{}/include", find_cuda_home().unwrap())) .parse_callbacks(Box::new(bindgen::CargoCallbacks::new())) // Allow the specified functions and types .allowlist_function("cu.*") @@ -113,6 +113,21 @@ fn main() { is_global: false, }); + // Add CUDA include path if available + if let Some(cuda_home) = find_cuda_home() { + let cuda_include_path = format!("{}/include", cuda_home); + if Path::new(&cuda_include_path).exists() { + builder = builder.clang_arg(format!("-I{}", cuda_include_path)); + } else { + eprintln!( + "Warning: CUDA include directory not found at {}", + cuda_include_path + ); + } + } else { + eprintln!("Warning: CUDA home directory not found. Continuing without CUDA include path."); + } + // Include headers and libs from the active environment. 
let (include_dir, lib_dir) = python_env_dirs(); if let Some(include_dir) = include_dir { @@ -127,17 +142,26 @@ fn main() { if let Some(cuda_home) = find_cuda_home() { emit_cuda_link_directives(&cuda_home); } - - // Write the bindings to the $OUT_DIR/bindings.rs file - let out_path = PathBuf::from(env::var("OUT_DIR").unwrap()); - builder - .generate() - .expect("Unable to generate bindings") - .write_to_file(out_path.join("bindings.rs")) - .expect("Couldn't write bindings!"); - println!("cargo:rustc-link-lib=cuda"); println!("cargo:rustc-link-lib=cudart"); - println!("cargo::rustc-cfg=cargo"); - println!("cargo::rustc-check-cfg=cfg(cargo)"); + + // Write the bindings to the $OUT_DIR/bindings.rs file + match env::var("OUT_DIR") { + Ok(out_dir) => { + let out_path = PathBuf::from(out_dir); + match builder.generate() { + Ok(bindings) => match bindings.write_to_file(out_path.join("bindings.rs")) { + Ok(_) => { + println!("cargo::rustc-cfg=cargo"); + println!("cargo::rustc-check-cfg=cfg(cargo)"); + } + Err(e) => eprintln!("Warning: Couldn't write bindings: {}", e), + }, + Err(e) => eprintln!("Warning: Unable to generate bindings: {}", e), + } + } + Err(_) => { + println!("Note: OUT_DIR not set, skipping bindings file generation"); + } + } } diff --git a/monarch_rdma/Cargo.toml b/monarch_rdma/Cargo.toml index 0b1025fa2..79416de9c 100644 --- a/monarch_rdma/Cargo.toml +++ b/monarch_rdma/Cargo.toml @@ -13,7 +13,7 @@ async-trait = "0.1.86" cuda-sys = { path = "../cuda-sys" } hyperactor = { version = "0.0.0", path = "../hyperactor" } rand = { version = "0.8", features = ["small_rng"] } -rdmacore-sys = { path = "../rdmacore-sys" } +rdmaxcel-sys = { path = "../rdmaxcel-sys" } serde = { version = "1.0.185", features = ["derive", "rc"] } tracing = { version = "0.1.41", features = ["attributes", "valuable"] } diff --git a/monarch_rdma/src/ibverbs_primitives.rs b/monarch_rdma/src/ibverbs_primitives.rs index 90ef5c3e4..c7c55bc1b 100644 --- a/monarch_rdma/src/ibverbs_primitives.rs +++ 
b/monarch_rdma/src/ibverbs_primitives.rs @@ -60,29 +60,29 @@ impl Gid { u64::from_be_bytes(self.raw[8..].try_into().unwrap()) } } -impl From for Gid { - fn from(gid: rdmacore_sys::ibv_gid) -> Self { +impl From for Gid { + fn from(gid: rdmaxcel_sys::ibv_gid) -> Self { Self { raw: unsafe { gid.raw }, } } } -impl From for rdmacore_sys::ibv_gid { +impl From for rdmaxcel_sys::ibv_gid { fn from(mut gid: Gid) -> Self { *gid.as_mut() } } -impl AsRef for Gid { - fn as_ref(&self) -> &rdmacore_sys::ibv_gid { - unsafe { &*self.raw.as_ptr().cast::() } +impl AsRef for Gid { + fn as_ref(&self) -> &rdmaxcel_sys::ibv_gid { + unsafe { &*self.raw.as_ptr().cast::() } } } -impl AsMut for Gid { - fn as_mut(&mut self) -> &mut rdmacore_sys::ibv_gid { - unsafe { &mut *self.raw.as_mut_ptr().cast::() } +impl AsMut for Gid { + fn as_mut(&mut self) -> &mut rdmaxcel_sys::ibv_gid { + unsafe { &mut *self.raw.as_mut_ptr().cast::() } } } @@ -143,7 +143,7 @@ impl Default for IbverbsConfig { max_recv_wr: 1, max_send_sge: 1, max_recv_sge: 1, - path_mtu: rdmacore_sys::IBV_MTU_1024, + path_mtu: rdmaxcel_sys::IBV_MTU_1024, retry_cnt: 7, rnr_retry: 7, qp_timeout: 14, // 4.096 μs * 2^14 = ~67 ms @@ -387,10 +387,10 @@ impl fmt::Display for RdmaPort { /// # Returns /// /// A string representation of the port state. -pub fn get_port_state_str(state: rdmacore_sys::ibv_port_state::Type) -> String { +pub fn get_port_state_str(state: rdmaxcel_sys::ibv_port_state::Type) -> String { // SAFETY: We are calling a C function that returns a C string. unsafe { - let c_str = rdmacore_sys::ibv_port_state_str(state); + let c_str = rdmaxcel_sys::ibv_port_state_str(state); if c_str.is_null() { return "Unknown".to_string(); } @@ -485,7 +485,7 @@ pub fn get_all_devices() -> Vec { // SAFETY: We are calling several C functions from libibverbs. 
unsafe { let mut num_devices = 0; - let device_list = rdmacore_sys::ibv_get_device_list(&mut num_devices); + let device_list = rdmaxcel_sys::ibv_get_device_list(&mut num_devices); if device_list.is_null() || num_devices == 0 { return devices; } @@ -496,18 +496,18 @@ pub fn get_all_devices() -> Vec { continue; } - let context = rdmacore_sys::ibv_open_device(device); + let context = rdmaxcel_sys::ibv_open_device(device); if context.is_null() { continue; } - let device_name = CStr::from_ptr(rdmacore_sys::ibv_get_device_name(device)) + let device_name = CStr::from_ptr(rdmaxcel_sys::ibv_get_device_name(device)) .to_string_lossy() .into_owned(); - let mut device_attr = rdmacore_sys::ibv_device_attr::default(); - if rdmacore_sys::ibv_query_device(context, &mut device_attr) != 0 { - rdmacore_sys::ibv_close_device(context); + let mut device_attr = rdmaxcel_sys::ibv_device_attr::default(); + if rdmaxcel_sys::ibv_query_device(context, &mut device_attr) != 0 { + rdmaxcel_sys::ibv_close_device(context); continue; } @@ -532,11 +532,11 @@ pub fn get_all_devices() -> Vec { }; for port_num in 1..=device_attr.phys_port_cnt { - let mut port_attr = rdmacore_sys::ibv_port_attr::default(); - if rdmacore_sys::ibv_query_port( + let mut port_attr = rdmaxcel_sys::ibv_port_attr::default(); + if rdmaxcel_sys::ibv_query_port( context, port_num, - &mut port_attr as *mut rdmacore_sys::ibv_port_attr as *mut _, + &mut port_attr as *mut rdmaxcel_sys::ibv_port_attr as *mut _, ) != 0 { continue; @@ -546,8 +546,8 @@ pub fn get_all_devices() -> Vec { let link_layer = get_link_layer_str(port_attr.link_layer); - let mut gid = rdmacore_sys::ibv_gid::default(); - let gid_str = if rdmacore_sys::ibv_query_gid(context, port_num, 0, &mut gid) == 0 { + let mut gid = rdmaxcel_sys::ibv_gid::default(); + let gid_str = if rdmaxcel_sys::ibv_query_gid(context, port_num, 0, &mut gid) == 0 { format_gid(&gid.raw) } else { "N/A".to_string() @@ -570,10 +570,10 @@ pub fn get_all_devices() -> Vec { } 
devices.push(rdma_device); - rdmacore_sys::ibv_close_device(context); + rdmaxcel_sys::ibv_close_device(context); } - rdmacore_sys::ibv_free_device_list(device_list); + rdmaxcel_sys::ibv_free_device_list(device_list); } devices @@ -592,9 +592,9 @@ pub fn ibverbs_supported() -> bool { // SAFETY: We are calling a C function from libibverbs. unsafe { let mut num_devices = 0; - let device_list = rdmacore_sys::ibv_get_device_list(&mut num_devices); + let device_list = rdmaxcel_sys::ibv_get_device_list(&mut num_devices); if !device_list.is_null() { - rdmacore_sys::ibv_free_device_list(device_list); + rdmaxcel_sys::ibv_free_device_list(device_list); } num_devices > 0 } @@ -670,20 +670,20 @@ pub enum RdmaOperation { Read, } -impl From for rdmacore_sys::ibv_wr_opcode::Type { +impl From for rdmaxcel_sys::ibv_wr_opcode::Type { fn from(op: RdmaOperation) -> Self { match op { - RdmaOperation::Write => rdmacore_sys::ibv_wr_opcode::IBV_WR_RDMA_WRITE, - RdmaOperation::Read => rdmacore_sys::ibv_wr_opcode::IBV_WR_RDMA_READ, + RdmaOperation::Write => rdmaxcel_sys::ibv_wr_opcode::IBV_WR_RDMA_WRITE, + RdmaOperation::Read => rdmaxcel_sys::ibv_wr_opcode::IBV_WR_RDMA_READ, } } } -impl From for RdmaOperation { - fn from(op: rdmacore_sys::ibv_wc_opcode::Type) -> Self { +impl From for RdmaOperation { + fn from(op: rdmaxcel_sys::ibv_wc_opcode::Type) -> Self { match op { - rdmacore_sys::ibv_wc_opcode::IBV_WC_RDMA_WRITE => RdmaOperation::Write, - rdmacore_sys::ibv_wc_opcode::IBV_WC_RDMA_READ => RdmaOperation::Read, + rdmaxcel_sys::ibv_wc_opcode::IBV_WC_RDMA_WRITE => RdmaOperation::Write, + rdmaxcel_sys::ibv_wc_opcode::IBV_WC_RDMA_READ => RdmaOperation::Read, _ => panic!("Unsupported operation type"), } } @@ -718,7 +718,7 @@ impl std::fmt::Debug for RdmaQpInfo { /// Wrapper around ibv_wc (ibverbs work completion). 
/// -/// This exposes only the public fields of rdmacore_sys::ibv_wc, allowing us to more easily +/// This exposes only the public fields of rdmaxcel_sys::ibv_wc, allowing us to more easily /// interact with it from Rust. Work completions are used to track the status of /// RDMA operations and are generated when an operation completes. #[derive(Debug, Named, Clone, serde::Serialize, serde::Deserialize)] @@ -730,9 +730,9 @@ pub struct IbvWc { /// `valid` - Whether the work completion is valid valid: bool, /// `error` - Error information if the operation failed - error: Option<(rdmacore_sys::ibv_wc_status::Type, u32)>, + error: Option<(rdmaxcel_sys::ibv_wc_status::Type, u32)>, /// `opcode` - Type of operation that completed (read, write, etc.) - opcode: rdmacore_sys::ibv_wc_opcode::Type, + opcode: rdmaxcel_sys::ibv_wc_opcode::Type, /// `bytes` - Immediate data (if any) bytes: Option, /// `qp_num` - Queue Pair Number @@ -749,8 +749,8 @@ pub struct IbvWc { dlid_path_bits: u8, } -impl From for IbvWc { - fn from(wc: rdmacore_sys::ibv_wc) -> Self { +impl From for IbvWc { + fn from(wc: rdmaxcel_sys::ibv_wc) -> Self { IbvWc { wr_id: wc.wr_id(), len: wc.len(), @@ -862,21 +862,21 @@ mod tests { #[test] fn test_rdma_operation_conversion() { assert_eq!( - rdmacore_sys::ibv_wr_opcode::IBV_WR_RDMA_WRITE, - rdmacore_sys::ibv_wr_opcode::Type::from(RdmaOperation::Write) + rdmaxcel_sys::ibv_wr_opcode::IBV_WR_RDMA_WRITE, + rdmaxcel_sys::ibv_wr_opcode::Type::from(RdmaOperation::Write) ); assert_eq!( - rdmacore_sys::ibv_wr_opcode::IBV_WR_RDMA_READ, - rdmacore_sys::ibv_wr_opcode::Type::from(RdmaOperation::Read) + rdmaxcel_sys::ibv_wr_opcode::IBV_WR_RDMA_READ, + rdmaxcel_sys::ibv_wr_opcode::Type::from(RdmaOperation::Read) ); assert_eq!( RdmaOperation::Write, - RdmaOperation::from(rdmacore_sys::ibv_wc_opcode::IBV_WC_RDMA_WRITE) + RdmaOperation::from(rdmaxcel_sys::ibv_wc_opcode::IBV_WC_RDMA_WRITE) ); assert_eq!( RdmaOperation::Read, - 
RdmaOperation::from(rdmacore_sys::ibv_wc_opcode::IBV_WC_RDMA_READ) + RdmaOperation::from(rdmaxcel_sys::ibv_wc_opcode::IBV_WC_RDMA_READ) ); } @@ -897,18 +897,18 @@ mod tests { #[test] fn test_ibv_wc() { - let mut wc = rdmacore_sys::ibv_wc::default(); + let mut wc = rdmaxcel_sys::ibv_wc::default(); // SAFETY: modifies private fields through pointer manipulation unsafe { // Cast to pointer and modify the fields directly - let wc_ptr = &mut wc as *mut rdmacore_sys::ibv_wc as *mut u8; + let wc_ptr = &mut wc as *mut rdmaxcel_sys::ibv_wc as *mut u8; // Set wr_id (at offset 0, u64) *(wc_ptr as *mut u64) = 42; // Set status to SUCCESS (at offset 8, u32) - *(wc_ptr.add(8) as *mut i32) = rdmacore_sys::ibv_wc_status::IBV_WC_SUCCESS as i32; + *(wc_ptr.add(8) as *mut i32) = rdmaxcel_sys::ibv_wc_status::IBV_WC_SUCCESS as i32; } let ibv_wc = IbvWc::from(wc); assert_eq!(ibv_wc.wr_id(), 42); diff --git a/monarch_rdma/src/rdma_components.rs b/monarch_rdma/src/rdma_components.rs index d4153f2cc..980aad059 100644 --- a/monarch_rdma/src/rdma_components.rs +++ b/monarch_rdma/src/rdma_components.rs @@ -51,8 +51,8 @@ use hyperactor::Mailbox; use hyperactor::Named; use hyperactor::clock::Clock; use hyperactor::clock::RealClock; -/// Direct access to low-level libibverbs rdmacore_sys. -use rdmacore_sys::ibv_qp_type; +/// Direct access to low-level libibverbs rdmaxcel_sys. +use rdmaxcel_sys::ibv_qp_type; use serde::Deserialize; use serde::Serialize; @@ -201,9 +201,9 @@ impl RdmaBuffer { /// * `mr_map`: A map of memory region IDs to pointers, representing registered memory regions. /// * `counter`: A counter for generating unique memory region IDs. 
pub struct RdmaDomain { - pub context: *mut rdmacore_sys::ibv_context, - pub pd: *mut rdmacore_sys::ibv_pd, - mr_map: HashMap, + pub context: *mut rdmaxcel_sys::ibv_context, + pub pd: *mut rdmaxcel_sys::ibv_pd, + mr_map: HashMap, counter: u32, } @@ -219,21 +219,21 @@ impl std::fmt::Debug for RdmaDomain { } // SAFETY: -// This function contains code marked unsafe as it interacts with the Rdma device through rdmacore_sys calls. +// This function contains code marked unsafe as it interacts with the Rdma device through rdmaxcel_sys calls. // RdmaDomain is `Send` because the raw pointers to ibverbs structs can be // accessed from any thread, and it is safe to drop `RdmaDomain` (and run the // ibverbs destructors) from any thread. unsafe impl Send for RdmaDomain {} // SAFETY: -// This function contains code marked unsafe as it interacts with the Rdma device through rdmacore_sys calls. +// This function contains code marked unsafe as it interacts with the Rdma device through rdmaxcel_sys calls. // RdmaDomain is `Sync` because the underlying ibverbs APIs are thread-safe. 
unsafe impl Sync for RdmaDomain {} impl Drop for RdmaDomain { fn drop(&mut self) { unsafe { - rdmacore_sys::ibv_dealloc_pd(self.pd); + rdmaxcel_sys::ibv_dealloc_pd(self.pd); } } } @@ -268,7 +268,7 @@ impl RdmaDomain { pub fn new(device: RdmaDevice) -> Result { tracing::debug!("creating RdmaDomain for device {}", device.name()); // SAFETY: - // This code uses unsafe rdmacore_sys calls to interact with the RDMA device, but is safe because: + // This code uses unsafe rdmaxcel_sys calls to interact with the RDMA device, but is safe because: // - All pointers are properly initialized and checked for null before use // - Memory registration follows the ibverbs API contract with proper access flags // - Resources are properly cleaned up in error cases to prevent leaks @@ -277,7 +277,7 @@ impl RdmaDomain { // Get the device based on the provided RdmaDevice let device_name = device.name(); let mut num_devices = 0i32; - let devices = rdmacore_sys::ibv_get_device_list(&mut num_devices as *mut _); + let devices = rdmaxcel_sys::ibv_get_device_list(&mut num_devices as *mut _); if devices.is_null() || num_devices == 0 { return Err(anyhow::anyhow!("no RDMA devices found")); @@ -288,7 +288,7 @@ impl RdmaDomain { for i in 0..num_devices { let dev = *devices.offset(i as isize); let dev_name = - CStr::from_ptr(rdmacore_sys::ibv_get_device_name(dev)).to_string_lossy(); + CStr::from_ptr(rdmaxcel_sys::ibv_get_device_name(dev)).to_string_lossy(); if dev_name == *device_name { device_ptr = dev; @@ -298,24 +298,24 @@ impl RdmaDomain { // If we didn't find the device, return an error if device_ptr.is_null() { - rdmacore_sys::ibv_free_device_list(devices); + rdmaxcel_sys::ibv_free_device_list(devices); return Err(anyhow::anyhow!("device '{}' not found", device_name)); } tracing::info!("using RDMA device: {}", device_name); // Open device - let context = rdmacore_sys::ibv_open_device(device_ptr); + let context = rdmaxcel_sys::ibv_open_device(device_ptr); if context.is_null() { - 
rdmacore_sys::ibv_free_device_list(devices); + rdmaxcel_sys::ibv_free_device_list(devices); let os_error = Error::last_os_error(); return Err(anyhow::anyhow!("failed to create context: {}", os_error)); } // Create protection domain - let pd = rdmacore_sys::ibv_alloc_pd(context); + let pd = rdmaxcel_sys::ibv_alloc_pd(context); if pd.is_null() { - rdmacore_sys::ibv_close_device(context); - rdmacore_sys::ibv_free_device_list(devices); + rdmaxcel_sys::ibv_close_device(context); + rdmaxcel_sys::ibv_free_device_list(devices); let os_error = Error::last_os_error(); return Err(anyhow::anyhow!( "failed to create protection domain (PD): {}", @@ -324,7 +324,7 @@ impl RdmaDomain { } // Avoids memory leaks - rdmacore_sys::ibv_free_device_list(devices); + rdmaxcel_sys::ibv_free_device_list(devices); Ok(RdmaDomain { context, @@ -350,10 +350,10 @@ impl RdmaDomain { ); let is_cuda = err == cuda_sys::CUresult::CUDA_SUCCESS; - let access = rdmacore_sys::ibv_access_flags::IBV_ACCESS_LOCAL_WRITE - | rdmacore_sys::ibv_access_flags::IBV_ACCESS_REMOTE_WRITE - | rdmacore_sys::ibv_access_flags::IBV_ACCESS_REMOTE_READ - | rdmacore_sys::ibv_access_flags::IBV_ACCESS_REMOTE_ATOMIC; + let access = rdmaxcel_sys::ibv_access_flags::IBV_ACCESS_LOCAL_WRITE + | rdmaxcel_sys::ibv_access_flags::IBV_ACCESS_REMOTE_WRITE + | rdmaxcel_sys::ibv_access_flags::IBV_ACCESS_REMOTE_READ + | rdmaxcel_sys::ibv_access_flags::IBV_ACCESS_REMOTE_ATOMIC; let mr; if is_cuda { @@ -365,9 +365,9 @@ impl RdmaDomain { cuda_sys::CUmemRangeHandleType::CU_MEM_RANGE_HANDLE_TYPE_DMA_BUF_FD, 0, ); - mr = rdmacore_sys::ibv_reg_dmabuf_mr(self.pd, 0, size, 0, fd, access.0 as i32); + mr = rdmaxcel_sys::ibv_reg_dmabuf_mr(self.pd, 0, size, 0, fd, access.0 as i32); } else { - mr = rdmacore_sys::ibv_reg_mr( + mr = rdmaxcel_sys::ibv_reg_mr( self.pd, addr as *mut std::ffi::c_void, size, @@ -396,7 +396,7 @@ impl RdmaDomain { let mr = self.mr_map.remove(&id); if mr.is_some() { unsafe { - rdmacore_sys::ibv_dereg_mr(mr.expect("mr is required")); 
+ rdmaxcel_sys::ibv_dereg_mr(mr.expect("mr is required")); } } Ok(()) @@ -445,9 +445,9 @@ impl RdmaDomain { #[derive(Debug, Serialize, Deserialize, Named, Clone)] pub struct RdmaQueuePair { - cq: usize, // *mut rdmacore_sys::ibv_cq, - qp: usize, // *mut rdmacore_sys::ibv_qp, - context: usize, // *mut rdmacore_sys::ibv_context, + cq: usize, // *mut rdmaxcel_sys::ibv_cq, + qp: usize, // *mut rdmaxcel_sys::ibv_qp, + context: usize, // *mut rdmaxcel_sys::ibv_context, config: IbverbsConfig, } @@ -473,19 +473,19 @@ impl RdmaQueuePair { /// * Completion queue (CQ) creation fails /// * Queue pair (QP) creation fails pub fn new( - context: *mut rdmacore_sys::ibv_context, - pd: *mut rdmacore_sys::ibv_pd, + context: *mut rdmaxcel_sys::ibv_context, + pd: *mut rdmaxcel_sys::ibv_pd, config: IbverbsConfig, ) -> Result { tracing::debug!("creating an RdmaQueuePair from config {}", config); // SAFETY: - // This code uses unsafe rdmacore_sys calls to interact with the RDMA device, but is safe because: + // This code uses unsafe rdmaxcel_sys calls to interact with the RDMA device, but is safe because: // - All pointers are properly initialized and checked for null before use // - Resources (CQ, QP) are created following the ibverbs API contract // - Error handling properly cleans up resources in failure cases // - The operations follow the documented RDMA protocol for queue pair initialization unsafe { - let cq = rdmacore_sys::ibv_create_cq( + let cq = rdmaxcel_sys::ibv_create_cq( context, config.cq_entries, std::ptr::null_mut(), @@ -501,12 +501,12 @@ impl RdmaQueuePair { } // Create queue pair - note we currently share a CQ for both send and receive for simplicity. 
- let mut qp_init_attr = rdmacore_sys::ibv_qp_init_attr { + let mut qp_init_attr = rdmaxcel_sys::ibv_qp_init_attr { qp_context: std::ptr::null::() as *mut _, send_cq: cq, recv_cq: cq, - srq: std::ptr::null::() as *mut _, - cap: rdmacore_sys::ibv_qp_cap { + srq: std::ptr::null::() as *mut _, + cap: rdmaxcel_sys::ibv_qp_cap { max_send_wr: config.max_send_wr, max_recv_wr: config.max_recv_wr, max_send_sge: config.max_send_sge, @@ -517,9 +517,9 @@ impl RdmaQueuePair { sq_sig_all: 0, }; - let qp = rdmacore_sys::ibv_create_qp(pd, &mut qp_init_attr); + let qp = rdmaxcel_sys::ibv_create_qp(pd, &mut qp_init_attr); if qp.is_null() { - rdmacore_sys::ibv_destroy_cq(cq); + rdmaxcel_sys::ibv_destroy_cq(cq); let os_error = Error::last_os_error(); return Err(anyhow::anyhow!( "failed to create queue pair (QP): {}", @@ -552,19 +552,19 @@ impl RdmaQueuePair { /// * GID query fails pub fn get_qp_info(&mut self) -> Result { // SAFETY: - // This code uses unsafe rdmacore_sys calls to query RDMA device information, but is safe because: + // This code uses unsafe rdmaxcel_sys calls to query RDMA device information, but is safe because: // - All pointers are properly initialized before use // - Port and GID queries follow the documented ibverbs API contract // - Error handling properly checks return codes from ibverbs functions // - The memory address provided is only stored, not dereferenced in this function unsafe { - let context = self.context as *mut rdmacore_sys::ibv_context; - let qp = self.qp as *mut rdmacore_sys::ibv_qp; - let mut port_attr = rdmacore_sys::ibv_port_attr::default(); - let errno = rdmacore_sys::ibv_query_port( + let context = self.context as *mut rdmaxcel_sys::ibv_context; + let qp = self.qp as *mut rdmaxcel_sys::ibv_qp; + let mut port_attr = rdmaxcel_sys::ibv_port_attr::default(); + let errno = rdmaxcel_sys::ibv_query_port( context, self.config.port_num, - &mut port_attr as *mut rdmacore_sys::ibv_port_attr as *mut _, + &mut port_attr as *mut 
rdmaxcel_sys::ibv_port_attr as *mut _, ); if errno != 0 { let os_error = Error::last_os_error(); @@ -575,7 +575,7 @@ impl RdmaQueuePair { } let mut gid = Gid::default(); - let ret = rdmacore_sys::ibv_query_gid( + let ret = rdmaxcel_sys::ibv_query_gid( context, self.config.port_num, i32::from(self.config.gid_index), @@ -595,18 +595,18 @@ impl RdmaQueuePair { } pub fn state(&mut self) -> Result { - // SAFETY: This block interacts with the RDMA device through rdmacore_sys calls. + // SAFETY: This block interacts with the RDMA device through rdmaxcel_sys calls. unsafe { - let qp = self.qp as *mut rdmacore_sys::ibv_qp; - let mut qp_attr = rdmacore_sys::ibv_qp_attr { + let qp = self.qp as *mut rdmaxcel_sys::ibv_qp; + let mut qp_attr = rdmaxcel_sys::ibv_qp_attr { ..Default::default() }; - let mut qp_init_attr = rdmacore_sys::ibv_qp_init_attr { + let mut qp_init_attr = rdmaxcel_sys::ibv_qp_init_attr { ..Default::default() }; - let mask = rdmacore_sys::ibv_qp_attr_mask::IBV_QP_STATE; + let mask = rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_STATE; let errno = - rdmacore_sys::ibv_query_qp(qp, &mut qp_attr, mask.0 as i32, &mut qp_init_attr); + rdmaxcel_sys::ibv_query_qp(qp, &mut qp_attr, mask.0 as i32, &mut qp_init_attr); if errno != 0 { let os_error = Error::last_os_error(); return Err(anyhow::anyhow!("failed to query QP state: {}", os_error)); @@ -623,7 +623,7 @@ impl RdmaQueuePair { /// * `connection_info` - The remote connection info to connect to pub fn connect(&mut self, connection_info: &RdmaQpInfo) -> Result<(), anyhow::Error> { // SAFETY: - // This unsafe block is necessary because we're interacting with the RDMA device through rdmacore_sys calls. + // This unsafe block is necessary because we're interacting with the RDMA device through rdmaxcel_sys calls. // The operations are safe because: // 1. We're following the documented ibverbs API contract // 2. All pointers used are properly initialized and owned by this struct @@ -631,26 +631,26 @@ impl RdmaQueuePair { // 4. 
Memory access is properly bounded by the registered memory regions unsafe { // Transition to INIT - let qp = self.qp as *mut rdmacore_sys::ibv_qp; + let qp = self.qp as *mut rdmaxcel_sys::ibv_qp; - let qp_access_flags = rdmacore_sys::ibv_access_flags::IBV_ACCESS_LOCAL_WRITE - | rdmacore_sys::ibv_access_flags::IBV_ACCESS_REMOTE_WRITE - | rdmacore_sys::ibv_access_flags::IBV_ACCESS_REMOTE_READ; + let qp_access_flags = rdmaxcel_sys::ibv_access_flags::IBV_ACCESS_LOCAL_WRITE + | rdmaxcel_sys::ibv_access_flags::IBV_ACCESS_REMOTE_WRITE + | rdmaxcel_sys::ibv_access_flags::IBV_ACCESS_REMOTE_READ; - let mut qp_attr = rdmacore_sys::ibv_qp_attr { - qp_state: rdmacore_sys::ibv_qp_state::IBV_QPS_INIT, + let mut qp_attr = rdmaxcel_sys::ibv_qp_attr { + qp_state: rdmaxcel_sys::ibv_qp_state::IBV_QPS_INIT, qp_access_flags: qp_access_flags.0, pkey_index: self.config.pkey_index, port_num: self.config.port_num, ..Default::default() }; - let mask = rdmacore_sys::ibv_qp_attr_mask::IBV_QP_STATE - | rdmacore_sys::ibv_qp_attr_mask::IBV_QP_PKEY_INDEX - | rdmacore_sys::ibv_qp_attr_mask::IBV_QP_PORT - | rdmacore_sys::ibv_qp_attr_mask::IBV_QP_ACCESS_FLAGS; + let mask = rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_STATE + | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_PKEY_INDEX + | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_PORT + | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_ACCESS_FLAGS; - let errno = rdmacore_sys::ibv_modify_qp(qp, &mut qp_attr, mask.0 as i32); + let errno = rdmaxcel_sys::ibv_modify_qp(qp, &mut qp_attr, mask.0 as i32); if errno != 0 { let os_error = Error::last_os_error(); return Err(anyhow::anyhow!( @@ -660,14 +660,14 @@ impl RdmaQueuePair { } // Transition to RTR (Ready to Receive) - let mut qp_attr = rdmacore_sys::ibv_qp_attr { - qp_state: rdmacore_sys::ibv_qp_state::IBV_QPS_RTR, + let mut qp_attr = rdmaxcel_sys::ibv_qp_attr { + qp_state: rdmaxcel_sys::ibv_qp_state::IBV_QPS_RTR, path_mtu: self.config.path_mtu, dest_qp_num: connection_info.qp_num, rq_psn: connection_info.psn, max_dest_rd_atomic: 
self.config.max_dest_rd_atomic, min_rnr_timer: self.config.min_rnr_timer, - ah_attr: rdmacore_sys::ibv_ah_attr { + ah_attr: rdmaxcel_sys::ibv_ah_attr { dlid: connection_info.lid, sl: 0, src_path_bits: 0, @@ -690,15 +690,15 @@ impl RdmaQueuePair { qp_attr.ah_attr.is_global = 0; } - let mask = rdmacore_sys::ibv_qp_attr_mask::IBV_QP_STATE - | rdmacore_sys::ibv_qp_attr_mask::IBV_QP_AV - | rdmacore_sys::ibv_qp_attr_mask::IBV_QP_PATH_MTU - | rdmacore_sys::ibv_qp_attr_mask::IBV_QP_DEST_QPN - | rdmacore_sys::ibv_qp_attr_mask::IBV_QP_RQ_PSN - | rdmacore_sys::ibv_qp_attr_mask::IBV_QP_MAX_DEST_RD_ATOMIC - | rdmacore_sys::ibv_qp_attr_mask::IBV_QP_MIN_RNR_TIMER; + let mask = rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_STATE + | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_AV + | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_PATH_MTU + | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_DEST_QPN + | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_RQ_PSN + | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_MAX_DEST_RD_ATOMIC + | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_MIN_RNR_TIMER; - let errno = rdmacore_sys::ibv_modify_qp(qp, &mut qp_attr, mask.0 as i32); + let errno = rdmaxcel_sys::ibv_modify_qp(qp, &mut qp_attr, mask.0 as i32); if errno != 0 { let os_error = Error::last_os_error(); return Err(anyhow::anyhow!( @@ -708,8 +708,8 @@ impl RdmaQueuePair { } // Transition to RTS (Ready to Send) - let mut qp_attr = rdmacore_sys::ibv_qp_attr { - qp_state: rdmacore_sys::ibv_qp_state::IBV_QPS_RTS, + let mut qp_attr = rdmaxcel_sys::ibv_qp_attr { + qp_state: rdmaxcel_sys::ibv_qp_state::IBV_QPS_RTS, sq_psn: self.config.psn, max_rd_atomic: self.config.max_rd_atomic, retry_cnt: self.config.retry_cnt, @@ -718,14 +718,14 @@ impl RdmaQueuePair { ..Default::default() }; - let mask = rdmacore_sys::ibv_qp_attr_mask::IBV_QP_STATE - | rdmacore_sys::ibv_qp_attr_mask::IBV_QP_TIMEOUT - | rdmacore_sys::ibv_qp_attr_mask::IBV_QP_RETRY_CNT - | rdmacore_sys::ibv_qp_attr_mask::IBV_QP_SQ_PSN - | rdmacore_sys::ibv_qp_attr_mask::IBV_QP_RNR_RETRY - | 
rdmacore_sys::ibv_qp_attr_mask::IBV_QP_MAX_QP_RD_ATOMIC; + let mask = rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_STATE + | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_TIMEOUT + | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_RETRY_CNT + | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_SQ_PSN + | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_RNR_RETRY + | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_MAX_QP_RD_ATOMIC; - let errno = rdmacore_sys::ibv_modify_qp(qp, &mut qp_attr, mask.0 as i32); + let errno = rdmaxcel_sys::ibv_modify_qp(qp, &mut qp_attr, mask.0 as i32); if errno != 0 { let os_error = Error::last_os_error(); return Err(anyhow::anyhow!( @@ -795,28 +795,28 @@ impl RdmaQueuePair { rkey: u32, ) -> Result<(), anyhow::Error> { // SAFETY: - // This code uses unsafe rdmacore_sys calls to post work requests to the RDMA device, but is safe because: + // This code uses unsafe rdmaxcel_sys calls to post work requests to the RDMA device, but is safe because: // - All pointers (send_sge, send_wr) are properly initialized on the stack before use // - The memory address in `local_addr` is not dereferenced, only passed to the device // - The remote connection info is verified to exist before accessing // - The ibverbs post_send operation follows the documented API contract // - Error codes from the device are properly checked and propagated unsafe { - let qp = self.qp as *mut rdmacore_sys::ibv_qp; - let context = self.context as *mut rdmacore_sys::ibv_context; - let mut send_sge = rdmacore_sys::ibv_sge { + let qp = self.qp as *mut rdmaxcel_sys::ibv_qp; + let context = self.context as *mut rdmaxcel_sys::ibv_context; + let mut send_sge = rdmaxcel_sys::ibv_sge { addr: laddr as u64, length: length as u32, lkey, }; let send_flags = if signaled { - rdmacore_sys::ibv_send_flags::IBV_SEND_SIGNALED.0 + rdmaxcel_sys::ibv_send_flags::IBV_SEND_SIGNALED.0 } else { 0 }; - let mut send_wr = rdmacore_sys::ibv_send_wr { + let mut send_wr = rdmaxcel_sys::ibv_send_wr { wr_id, next: std::ptr::null_mut(), sg_list: &mut send_sge 
as *mut _, @@ -832,7 +832,7 @@ impl RdmaQueuePair { // Set remote address and rkey for RDMA operations send_wr.wr.rdma.remote_addr = raddr as u64; send_wr.wr.rdma.rkey = rkey; - let mut bad_send_wr: *mut rdmacore_sys::ibv_send_wr = std::ptr::null_mut(); + let mut bad_send_wr: *mut rdmaxcel_sys::ibv_send_wr = std::ptr::null_mut(); let ops = &mut (*context).ops; let errno = ops.post_send.as_mut().unwrap()(qp, &mut send_wr as *mut _, &mut bad_send_wr); @@ -873,16 +873,16 @@ impl RdmaQueuePair { /// * `Err(e)` - An error occurred pub fn poll_completion(&self) -> Result, anyhow::Error> { // SAFETY: - // This code uses unsafe rdmacore_sys calls to poll the completion queue, but is safe because: + // This code uses unsafe rdmaxcel_sys calls to poll the completion queue, but is safe because: // - The completion queue pointer is properly initialized and owned by this struct // - The work completion structure is properly zeroed before use // - We only access the completion queue through the documented ibverbs API // - Error codes from polling operations are properly checked and propagated // - The work completion validity is verified before returning it to the caller unsafe { - let context = self.context as *mut rdmacore_sys::ibv_context; - let cq = self.cq as *mut rdmacore_sys::ibv_cq; - let mut wc = std::mem::MaybeUninit::::zeroed().assume_init(); + let context = self.context as *mut rdmaxcel_sys::ibv_context; + let cq = self.cq as *mut rdmaxcel_sys::ibv_cq; + let mut wc = std::mem::MaybeUninit::::zeroed().assume_init(); let ops = &mut (*context).ops; let ret = ops.poll_cq.as_mut().unwrap()(cq, 1, &mut wc); diff --git a/monarch_rdma/src/rdma_manager_actor.rs b/monarch_rdma/src/rdma_manager_actor.rs index f4b32cce5..d32dc877a 100644 --- a/monarch_rdma/src/rdma_manager_actor.rs +++ b/monarch_rdma/src/rdma_manager_actor.rs @@ -302,7 +302,7 @@ impl RdmaManagerMessageHandler for RdmaManagerActor { .get_mut(&other.actor_id().clone()) .unwrap() .state()?; - Ok(qp_state == 
rdmacore_sys::ibv_qp_state::IBV_QPS_RTS) + Ok(qp_state == rdmaxcel_sys::ibv_qp_state::IBV_QPS_RTS) } /// Establishes a connection with another actor diff --git a/rdmacore-sys/build.rs b/rdmacore-sys/build.rs deleted file mode 100644 index f3a53622b..000000000 --- a/rdmacore-sys/build.rs +++ /dev/null @@ -1,69 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -use std::env; -use std::path::PathBuf; - -fn main() { - // Tell cargo to look for shared libraries in the specified directory - println!("cargo:rustc-link-search=/usr/lib"); - println!("cargo:rustc-link-search=/usr/lib64"); - - // Link against the ibverbs library - println!("cargo:rustc-link-lib=ibverbs"); - - // Link against the mlx5 library - println!("cargo:rustc-link-lib=mlx5"); - - // Tell cargo to invalidate the built crate whenever the wrapper changes - println!("cargo:rerun-if-changed=src/wrapper.h"); - - // Add cargo metadata - println!("cargo:rustc-cfg=cargo"); - println!("cargo:rustc-check-cfg=cfg(cargo)"); - - // The bindgen::Builder is the main entry point to bindgen - let bindings = bindgen::Builder::default() - // The input header we would like to generate bindings for - .header("src/wrapper.h") - // Allow the specified functions, types, and variables - .allowlist_function("ibv_.*") - .allowlist_function("mlx5dv_.*") - .allowlist_function("mlx5_wqe_.*") - .allowlist_type("ibv_.*") - .allowlist_type("mlx5dv_.*") - .allowlist_type("mlx5_wqe_.*") - .allowlist_var("MLX5_.*") - // Block specific types that are manually defined in lib.rs - .blocklist_type("ibv_wc") - .blocklist_type("mlx5_wqe_ctrl_seg") - // Apply the same bindgen flags as in the BUCK file - .bitfield_enum("ibv_access_flags") - .bitfield_enum("ibv_qp_attr_mask") - .bitfield_enum("ibv_wc_flags") - .bitfield_enum("ibv_send_flags") - 
.bitfield_enum("ibv_port_cap_flags") - .constified_enum_module("ibv_qp_type") - .constified_enum_module("ibv_qp_state") - .constified_enum_module("ibv_port_state") - .constified_enum_module("ibv_wc_opcode") - .constified_enum_module("ibv_wr_opcode") - .constified_enum_module("ibv_wc_status") - .derive_default(true) - .prepend_enum_name(false) - // Finish the builder and generate the bindings - .generate() - // Unwrap the Result and panic on failure - .expect("Unable to generate bindings"); - - // Write the bindings to the $OUT_DIR/bindings.rs file - let out_path = PathBuf::from(env::var("OUT_DIR").unwrap()); - bindings - .write_to_file(out_path.join("bindings.rs")) - .expect("Couldn't write bindings!"); -} diff --git a/rdmacore-sys/Cargo.toml b/rdmaxcel-sys/Cargo.toml similarity index 68% rename from rdmacore-sys/Cargo.toml rename to rdmaxcel-sys/Cargo.toml index 0ac03be8c..021103810 100644 --- a/rdmacore-sys/Cargo.toml +++ b/rdmaxcel-sys/Cargo.toml @@ -1,10 +1,10 @@ [package] -name = "rdmacore-sys" -version = "0.0.0" +name = "rdmaxcel-sys" +version = "0.0.1" authors = ["Facebook"] edition = "2021" license = "MIT" -links = "ibverbs" +links = "rdmaxcel" [dependencies] cxx = "1.0.119" @@ -12,3 +12,5 @@ serde = { version = "1.0.185", features = ["derive", "rc"] } [build-dependencies] bindgen = "0.70.1" +which = "6.0.3" +glob = "0.3.1" diff --git a/rdmaxcel-sys/build.rs b/rdmaxcel-sys/build.rs new file mode 100644 index 000000000..a81f1d103 --- /dev/null +++ b/rdmaxcel-sys/build.rs @@ -0,0 +1,244 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +use std::env; +use std::path::Path; +use std::path::PathBuf; + +use glob::glob; +use which::which; + +const PYTHON_PRINT_DIRS: &str = r" +import sysconfig +print('PYTHON_INCLUDE_DIR:', sysconfig.get_config_var('INCLUDEDIR')) +print('PYTHON_LIB_DIR:', sysconfig.get_config_var('LIBDIR')) +"; + +// Translated from torch/utils/cpp_extension.py +fn find_cuda_home() -> Option<String> { + // Guess #1 + let mut cuda_home = env::var("CUDA_HOME") + .ok() + .or_else(|| env::var("CUDA_PATH").ok()); + + if cuda_home.is_none() { + // Guess #2 + if let Ok(nvcc_path) = which("nvcc") { + // Get parent directory twice + if let Some(cuda_dir) = nvcc_path.parent().and_then(|p| p.parent()) { + cuda_home = Some(cuda_dir.to_string_lossy().into_owned()); + } + } else { + // Guess #3 + if cfg!(windows) { + // Windows code + let pattern = r"C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v*.*"; + let cuda_homes: Vec<_> = glob(pattern).unwrap().filter_map(Result::ok).collect(); + if !cuda_homes.is_empty() { + cuda_home = Some(cuda_homes[0].to_string_lossy().into_owned()); + } else { + cuda_home = None; + } + } else { + // Walk through possible locations, starting with newest + for candidate in &[ + "/usr/local/cuda-12.8", + "/usr/local/cuda-12.6", + "/usr/local/cuda-12.4", + "/usr/local/cuda-12.2", + "/usr/local/cuda-12.1", + "/usr/local/cuda-12.0", + "/usr/local/cuda-11.8", + "/usr/local/cuda-11.7", + "/usr/local/cuda-11.6", + "/usr/local/cuda-11.5", + ] { + if Path::new(candidate).exists() { + cuda_home = Some(candidate.to_string()); + break; + } + } + } + } + } + cuda_home +} + +fn emit_cuda_link_directives(cuda_home: &str) { + let stubs_path = format!("{}/lib64/stubs", cuda_home); + if Path::new(&stubs_path).exists() { + println!("cargo:rustc-link-search=native={}", stubs_path); + } else { + let lib64_path = format!("{}/lib64", cuda_home); + if Path::new(&lib64_path).exists() { + println!("cargo:rustc-link-search=native={}", lib64_path); + } + } + + 
println!("cargo:rustc-link-lib=cuda"); + println!("cargo:rustc-link-lib=cudart"); +} + +fn python_env_dirs() -> (Option<String>, Option<String>) { + let output = std::process::Command::new(PathBuf::from("python3")) + .arg("-c") + .arg(PYTHON_PRINT_DIRS) + .output() + .unwrap_or_else(|_| panic!("error running python")); + + let mut include_dir = None; + let mut lib_dir = None; + for line in String::from_utf8_lossy(&output.stdout).lines() { + if let Some(path) = line.strip_prefix("PYTHON_INCLUDE_DIR: ") { + include_dir = Some(path.to_string()); + } + if let Some(path) = line.strip_prefix("PYTHON_LIB_DIR: ") { + lib_dir = Some(path.to_string()); + } + } + (include_dir, lib_dir) +} + +fn main() { + // Tell cargo to look for shared libraries in the specified directory + println!("cargo:rustc-link-search=/usr/lib"); + println!("cargo:rustc-link-search=/usr/lib64"); + + // Link against the ibverbs library + println!("cargo:rustc-link-lib=ibverbs"); + + // Link against the mlx5 library + println!("cargo:rustc-link-lib=mlx5"); + + // Tell cargo to invalidate the built crate whenever the wrapper changes + println!("cargo:rerun-if-changed=src/rdmaxcel.h"); + + // Get the directory of the current crate + let manifest_dir = env::var("CARGO_MANIFEST_DIR").unwrap_or_else(|_| { + // For buck2 run, we know the package is in fbcode/monarch/rdmaxcel-sys + // Get the fbsource directory from the current directory path + let current_dir = std::env::current_dir().expect("Failed to get current directory"); + let current_path = current_dir.to_string_lossy(); + + // Find the fbsource part of the path + if let Some(fbsource_pos) = current_path.find("fbsource") { + let fbsource_path = &current_path[..fbsource_pos + "fbsource".len()]; + format!("{}/fbcode/monarch/rdmaxcel-sys", fbsource_path) + } else { + // If we can't find fbsource in the path, just use the current directory + format!("{}/src", current_dir.to_string_lossy()) + } + }); + + // Create the absolute path to the header file + let header_path = 
format!("{}/src/rdmaxcel.h", manifest_dir); + + // Check if the header file exists + if !Path::new(&header_path).exists() { + panic!("Header file not found at {}", header_path); + } + + // Start building the bindgen configuration + let mut builder = bindgen::Builder::default() + // The input header we would like to generate bindings for + .header(&header_path) + .clang_arg("-x") + .clang_arg("c++") + .clang_arg("-std=gnu++20") + .parse_callbacks(Box::new(bindgen::CargoCallbacks::new())) + // Allow the specified functions, types, and variables + .allowlist_function("ibv_.*") + .allowlist_function("mlx5dv_.*") + .allowlist_function("mlx5_wqe_.*") + .allowlist_function("create_qp") + .allowlist_function("create_mlx5dv_.*") + .allowlist_function("register_cuda_memory") + .allowlist_function("db_ring") + .allowlist_function("cqe_poll") + .allowlist_function("send_wqe") + .allowlist_function("recv_wqe") + .allowlist_function("launch_db_ring") + .allowlist_function("launch_cqe_poll") + .allowlist_function("launch_send_wqe") + .allowlist_function("launch_recv_wqe") + .allowlist_type("ibv_.*") + .allowlist_type("mlx5dv_.*") + .allowlist_type("mlx5_wqe_.*") + .allowlist_type("cqe_poll_result_t") + .allowlist_type("wqe_params_t") + .allowlist_type("cqe_poll_params_t") + .allowlist_var("MLX5_.*") + .allowlist_var("IBV_.*") + // Block specific types that are manually defined in lib.rs + .blocklist_type("ibv_wc") + .blocklist_type("mlx5_wqe_ctrl_seg") + // Apply the same bindgen flags as in the BUCK file + .bitfield_enum("ibv_access_flags") + .bitfield_enum("ibv_qp_attr_mask") + .bitfield_enum("ibv_wc_flags") + .bitfield_enum("ibv_send_flags") + .bitfield_enum("ibv_port_cap_flags") + .constified_enum_module("ibv_qp_type") + .constified_enum_module("ibv_qp_state") + .constified_enum_module("ibv_port_state") + .constified_enum_module("ibv_wc_opcode") + .constified_enum_module("ibv_wr_opcode") + .constified_enum_module("ibv_wc_status") + .derive_default(true) + 
.prepend_enum_name(false); + + // Add CUDA include path if available + if let Some(cuda_home) = find_cuda_home() { + let cuda_include_path = format!("{}/include", cuda_home); + if Path::new(&cuda_include_path).exists() { + builder = builder.clang_arg(format!("-I{}", cuda_include_path)); + } else { + eprintln!( + "Warning: CUDA include directory not found at {}", + cuda_include_path + ); + } + } else { + eprintln!("Warning: CUDA home directory not found. Continuing without CUDA include path."); + } + + // Include headers and libs from the active environment. + let (include_dir, lib_dir) = python_env_dirs(); + if let Some(include_dir) = include_dir { + builder = builder.clang_arg(format!("-I{}", include_dir)); + } + if let Some(lib_dir) = lib_dir { + println!("cargo::rustc-link-search=native={}", lib_dir); + // Set cargo metadata to inform dependent binaries about how to set their + // RPATH (see controller/build.rs for an example). + println!("cargo::metadata=LIB_PATH={}", lib_dir); + } + if let Some(cuda_home) = find_cuda_home() { + emit_cuda_link_directives(&cuda_home); + } + + // Generate bindings + let bindings = builder.generate().expect("Unable to generate bindings"); + + // Write the bindings to the $OUT_DIR/bindings.rs file + match env::var("OUT_DIR") { + Ok(out_dir) => { + let out_path = PathBuf::from(out_dir); + match bindings.write_to_file(out_path.join("bindings.rs")) { + Ok(_) => { + println!("cargo:rustc-cfg=cargo"); + println!("cargo:rustc-check-cfg=cfg(cargo)"); + } + Err(e) => eprintln!("Warning: Couldn't write bindings: {}", e), + } + } + Err(_) => { + println!("Note: OUT_DIR not set, skipping bindings file generation"); + } + } +} diff --git a/rdmacore-sys/src/lib.rs b/rdmaxcel-sys/src/lib.rs similarity index 100% rename from rdmacore-sys/src/lib.rs rename to rdmaxcel-sys/src/lib.rs diff --git a/rdmaxcel-sys/src/rdmaxcel.c b/rdmaxcel-sys/src/rdmaxcel.c new file mode 100644 index 000000000..5569a914d --- /dev/null +++ b/rdmaxcel-sys/src/rdmaxcel.c 
@@ -0,0 +1,207 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "rdmaxcel.h" + +#include +#include +#include + +cudaError_t register_mmio_to_cuda(void* bf, size_t size) { + cudaError_t result = cudaHostRegister( + bf, + size, + cudaHostRegisterMapped | cudaHostRegisterPortable | + cudaHostRegisterIoMemory); + return result; +} + +struct ibv_qp* create_qp( + struct ibv_context* context, + struct ibv_pd* pd, + int cq_entries, + int max_send_wr, + int max_recv_wr, + int max_send_sge, + int max_recv_sge) { + // Create separate completion queues for send and receive operations + struct ibv_cq* send_cq = ibv_create_cq(context, cq_entries, NULL, NULL, 0); + if (!send_cq) { + perror("failed to create send completion queue (CQ)"); + return NULL; + } + + struct ibv_cq* recv_cq = ibv_create_cq(context, cq_entries, NULL, NULL, 0); + if (!recv_cq) { + perror("failed to create receive completion queue (CQ)"); + ibv_destroy_cq(send_cq); + return NULL; + } + + // Initialize queue pair attributes + struct ibv_qp_init_attr qp_init_attr = { + .qp_context = NULL, + .send_cq = send_cq, + .recv_cq = recv_cq, + .srq = NULL, + .cap = + { + .max_send_wr = max_send_wr, + .max_recv_wr = max_recv_wr, + .max_send_sge = max_send_sge, + .max_recv_sge = max_recv_sge, + .max_inline_data = 0, + }, + .qp_type = IBV_QPT_RC, + .sq_sig_all = 0, + }; + + // Create queue pair + struct ibv_qp* qp = ibv_create_qp(pd, &qp_init_attr); + if (!qp) { + perror("failed to create queue pair (QP)"); + ibv_destroy_cq(send_cq); + ibv_destroy_cq(recv_cq); + return NULL; + } + + return qp; +} + +struct mlx5dv_qp* create_mlx5dv_qp(struct ibv_qp* qp) { + struct mlx5dv_qp* dv_qp = malloc(sizeof(struct mlx5dv_qp)); + struct mlx5dv_obj dv_obj; + memset(&dv_obj, 0, sizeof(dv_obj)); + memset(dv_qp, 0, sizeof(*dv_qp)); + + dv_obj.qp.in = 
qp; + dv_obj.qp.out = dv_qp; + int ret = mlx5dv_init_obj(&dv_obj, MLX5DV_OBJ_QP); + if (ret != 0) { + perror("failed to init mlx5dv_qp"); + free(dv_qp); + return NULL; + } + + return dv_qp; +} + +struct mlx5dv_cq* create_mlx5dv_cq(struct ibv_qp* qp) { + // We'll use the receive CQ for now, but in the future this will be updated + // to handle both send and receive CQs separately + struct mlx5dv_cq* dv_cq = malloc(sizeof(struct mlx5dv_cq)); + struct mlx5dv_obj dv_obj; + memset(&dv_obj, 0, sizeof(dv_obj)); + memset(dv_cq, 0, sizeof(*dv_cq)); + + dv_obj.cq.in = qp->recv_cq; + dv_obj.cq.out = dv_cq; + int ret = mlx5dv_init_obj(&dv_obj, MLX5DV_OBJ_CQ); + if (ret != 0) { + perror("failed to init mlx5dv_cq"); + free(dv_cq); + return NULL; + } + return dv_cq; +} + +struct mlx5dv_cq* create_mlx5dv_send_cq(struct ibv_qp* qp) { + struct mlx5dv_cq* dv_cq = malloc(sizeof(struct mlx5dv_cq)); + struct mlx5dv_obj dv_obj; + memset(&dv_obj, 0, sizeof(dv_obj)); + memset(dv_cq, 0, sizeof(*dv_cq)); + + dv_obj.cq.in = qp->send_cq; + dv_obj.cq.out = dv_cq; + int ret = mlx5dv_init_obj(&dv_obj, MLX5DV_OBJ_CQ); + if (ret != 0) { + perror("failed to init mlx5dv_send_cq"); + free(dv_cq); + return NULL; + } + return dv_cq; +} + +struct mlx5dv_cq* create_mlx5dv_recv_cq(struct ibv_qp* qp) { + struct mlx5dv_cq* dv_cq = malloc(sizeof(struct mlx5dv_cq)); + struct mlx5dv_obj dv_obj; + memset(&dv_obj, 0, sizeof(dv_obj)); + memset(dv_cq, 0, sizeof(*dv_cq)); + + dv_obj.cq.in = qp->recv_cq; + dv_obj.cq.out = dv_cq; + int ret = mlx5dv_init_obj(&dv_obj, MLX5DV_OBJ_CQ); + if (ret != 0) { + perror("failed to init mlx5dv_recv_cq"); + free(dv_cq); + return NULL; + } + return dv_cq; +} + +cudaError_t register_cuda_memory( + struct mlx5dv_qp* dv_qp, + struct mlx5dv_cq* dv_recv_cq, + struct mlx5dv_cq* dv_send_cq) { + cudaError_t ret; + + ret = cudaHostRegister( + dv_qp->sq.buf, + dv_qp->sq.stride * dv_qp->sq.wqe_cnt, + cudaHostRegisterMapped | cudaHostRegisterPortable); + if (ret != cudaSuccess) { + return ret; 
+ } + + ret = cudaHostRegister( + dv_qp->bf.reg, + dv_qp->bf.size, + cudaHostRegisterMapped | cudaHostRegisterPortable | + cudaHostRegisterIoMemory); + if (ret != cudaSuccess) { + return ret; + } + + ret = cudaHostRegister( + dv_qp->dbrec, 8, cudaHostRegisterMapped | cudaHostRegisterPortable); + if (ret != cudaSuccess) { + return ret; + } + + // Register receive completion queue + ret = cudaHostRegister( + dv_recv_cq->buf, + dv_recv_cq->cqe_size * dv_recv_cq->cqe_cnt, + cudaHostRegisterMapped | cudaHostRegisterPortable); + if (ret != cudaSuccess) { + return ret; + } + + ret = cudaHostRegister( + dv_recv_cq->dbrec, 4, cudaHostRegisterMapped | cudaHostRegisterPortable); + if (ret != cudaSuccess) { + return ret; + } + + // Register send completion queue + ret = cudaHostRegister( + dv_send_cq->buf, + dv_send_cq->cqe_size * dv_send_cq->cqe_cnt, + cudaHostRegisterMapped | cudaHostRegisterPortable); + if (ret != cudaSuccess) { + return ret; + } + + ret = cudaHostRegister( + dv_send_cq->dbrec, 4, cudaHostRegisterMapped | cudaHostRegisterPortable); + if (ret != cudaSuccess) { + return ret; + } + + return cudaSuccess; +} diff --git a/rdmaxcel-sys/src/rdmaxcel.cu b/rdmaxcel-sys/src/rdmaxcel.cu new file mode 100644 index 000000000..4f252ea2e --- /dev/null +++ b/rdmaxcel-sys/src/rdmaxcel.cu @@ -0,0 +1,420 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include +#include +#include +#include "rdmaxcel.h" + +//------------------------------------------------------------------------------ +// Byte Swapping Utilities +//------------------------------------------------------------------------------ + +/** + * @brief Swaps the byte order of a 16-bit value (converts between little and + * big endian) + * + * This function is used for endianness conversion when communicating with + * InfiniBand hardware, which uses big-endian byte ordering. + * + * @param val The 16-bit value to swap + * @return The byte-swapped value + */ +__host__ __device__ static inline uint16_t byte_swap16(uint16_t val) { + return ((val & 0xFF00) >> 8) | ((val & 0x00FF) << 8); +} + +/** + * @brief Swaps the byte order of a 32-bit value (converts between little and + * big endian) + * + * This function is used for endianness conversion when communicating with + * InfiniBand hardware, which uses big-endian byte ordering. + * + * @param val The 32-bit value to swap + * @return The byte-swapped value + */ +__host__ __device__ static inline uint32_t byte_swap32(uint32_t val) { + return ((val & 0xFF000000) >> 24) | ((val & 0x00FF0000) >> 8) | + ((val & 0x0000FF00) << 8) | ((val & 0x000000FF) << 24); +} + +/** + * @brief Swaps the byte order of a 64-bit value (converts between little and + * big endian) + * + * This function is used for endianness conversion when communicating with + * InfiniBand hardware, which uses big-endian byte ordering. 
+ * + * @param val The 64-bit value to swap + * @return The byte-swapped value + */ +__host__ __device__ static inline uint64_t byte_swap64(uint64_t val) { + return ((val & 0xFF00000000000000ULL) >> 56) | + ((val & 0x00FF000000000000ULL) >> 40) | + ((val & 0x0000FF0000000000ULL) >> 24) | + ((val & 0x000000FF00000000ULL) >> 8) | + ((val & 0x00000000FF000000ULL) << 8) | + ((val & 0x0000000000FF0000ULL) << 24) | + ((val & 0x000000000000FF00ULL) << 40) | + ((val & 0x00000000000000FFULL) << 56); +} + +//------------------------------------------------------------------------------ +// Doorbell Operations +//------------------------------------------------------------------------------ + +/** + * @brief Rings a doorbell by copying 8 64-bit values from source to destination + * + * This function is used to notify the HCA (Host Channel Adapter) that new work + * has been queued. It copies 8 64-bit values (64 bytes total) from the source + * to the destination, which is typically a memory-mapped doorbell register. + * + * @param dst Pointer to the destination (doorbell register) + * @param src Pointer to the source data + */ +__host__ __device__ void db_ring(void* dst, void* src) { + volatile uint64_t* dst_v = (uint64_t*)dst; + volatile uint64_t* src_v = (uint64_t*)src; + dst_v[0] = src_v[0]; + dst_v[1] = src_v[1]; + dst_v[2] = src_v[2]; + dst_v[3] = src_v[3]; + dst_v[4] = src_v[4]; + dst_v[5] = src_v[5]; + dst_v[6] = src_v[6]; + dst_v[7] = src_v[7]; +} + +/** + * @brief CUDA kernel wrapper for db_ring function + * + * This kernel launches a single thread to execute the db_ring function on the + * GPU. It includes memory fences to ensure proper ordering of memory + * operations. 
+ * + * @param dst Pointer to the destination (doorbell register) + * @param src Pointer to the source data + */ +__global__ void cu_db_ring(void* dst, void* src) { + uint32_t i = blockIdx.x * blockDim.x + threadIdx.x; + if (i == 0) { + db_ring(dst, src); + } + __syncthreads(); + __threadfence_system(); +} + +/** + * @brief Host function to launch the cu_db_ring kernel + * + * This function launches the cu_db_ring kernel with a single thread. + * + * @param dst Pointer to the destination (doorbell register) + * @param src Pointer to the source data + */ +void launch_db_ring(void* dst, void* src) { + cu_db_ring<<<1, 1>>>(dst, src); +} + +//------------------------------------------------------------------------------ +// Work Queue Element (WQE) Operations +//------------------------------------------------------------------------------ + +/** + * @brief Creates and posts a receive WQE (Work Queue Element) + * + * This function creates a receive WQE with the specified parameters and posts + * it to the receive queue. For MLX5 receive WQEs, it creates a data segment and + * updates the doorbell record. + * + * @param params Structure containing all parameters needed for the receive WQE + */ +__host__ __device__ void recv_wqe(wqe_params_t params) { + // For MLX5 receive WQEs, we need to create a proper structure with: + // 1. A next segment (mlx5_wqe_srq_next_seg) + // 2. 
A data segment (mlx5_wqe_data_seg) + + // Declare individual segments instead of using the combined struct + struct mlx5_wqe_data_seg data_seg; + + // Initialize the data segment + data_seg.byte_count = byte_swap32(params.length); + data_seg.lkey = byte_swap32(params.lkey); + data_seg.addr = byte_swap64(params.laddr); + + // Calculate pointers for segments + uintptr_t data_seg_ptr = (uintptr_t)params.buf; + + // Copy segments to WQE buffer + memcpy((void*)data_seg_ptr, &data_seg, sizeof(data_seg)); + + volatile uint32_t* dbrec = params.dbrec; // Declare a volatile pointer + dbrec[MLX5_RCV_DBR] = byte_swap32(params.wr_id + 1); +} + +/** + * @brief CUDA kernel wrapper for recv_wqe function + * + * This kernel launches a single thread to execute the recv_wqe function on the + * GPU. + * + * @param params Structure containing all parameters needed for the receive WQE + */ +__global__ void cu_recv_wqe(wqe_params_t params) { + if (threadIdx.x == 0 && blockIdx.x == 0) { + recv_wqe(params); + } +} + +/** + * @brief Host function to launch the cu_recv_wqe kernel + * + * This function launches the cu_recv_wqe kernel with a single thread and + * synchronizes the device to ensure completion. + * + * @param params Structure containing all parameters needed for the receive WQE + */ +void launch_recv_wqe(wqe_params_t params) { + // Launch kernel + cu_recv_wqe<<<1, 1>>>(params); + + // Wait for kernel to complete + cudaDeviceSynchronize(); +} + +/** + * @brief Creates and posts a send WQE (Work Queue Element) + * + * This function creates a send WQE with the specified parameters and posts it + * to the send queue. It creates control, remote address, and data segments, + * and updates the doorbell record. 
+ * + * @param params Structure containing all parameters needed for the send WQE + */ +__host__ __device__ void send_wqe(wqe_params_t params) { + struct mlx5_wqe_ctrl_seg ctrl_seg = {0}; + struct mlx5_wqe_data_seg data_seg = {0}; + struct mlx5_wqe_raddr_seg raddr_seg = {0}; + + uint32_t idx = params.wr_id; + uint32_t buffer_idx = idx & (params.wqe_cnt - 1); + + // Set control segment + ctrl_seg.fm_ce_se = + params.signaled ? MLX5_WQE_CTRL_CQ_UPDATE | MLX5_WQE_CTRL_SOLICITED : 0; + + // Set opcode based on operation type + ctrl_seg.opmod_idx_opcode = ((idx << 8) | params.op_type); + + // Convert to big endian + ctrl_seg.opmod_idx_opcode = byte_swap32(ctrl_seg.opmod_idx_opcode); + + // Set QP number and data size (48 bytes / 16 = 3 DS) + ctrl_seg.qpn_ds = (params.qp_num << 8 | (48 / 16)); + ctrl_seg.qpn_ds = byte_swap32(ctrl_seg.qpn_ds); + + // Set remote address segment + raddr_seg.raddr = byte_swap64(params.raddr); + raddr_seg.rkey = byte_swap32(params.rkey); + + // Set data segment + data_seg.addr = byte_swap64(params.laddr); + data_seg.byte_count = byte_swap32(params.length); + data_seg.lkey = byte_swap32(params.lkey); + + // Calculate pointers for segments + uintptr_t ctrl_seg_ptr = + (uintptr_t)(params.buf) + (buffer_idx << MLX5_SEND_WQE_SHIFT); + uintptr_t raddr_seg_ptr = ctrl_seg_ptr + sizeof(ctrl_seg); + uintptr_t data_seg_ptr = raddr_seg_ptr + sizeof(raddr_seg); + + // Copy segments to WQE buffer + memcpy((void*)ctrl_seg_ptr, &ctrl_seg, sizeof(ctrl_seg)); + memcpy((void*)raddr_seg_ptr, &raddr_seg, sizeof(raddr_seg)); + memcpy((void*)data_seg_ptr, &data_seg, sizeof(data_seg)); + + volatile uint32_t* dbrec = params.dbrec; + dbrec[MLX5_SND_DBR] = byte_swap32((idx + 1) & 0xFFFFFF); +} + +/** + * @brief CUDA kernel wrapper for send_wqe function + * + * This kernel launches a single thread to execute the send_wqe function on the + * GPU. 
+ * + * @param params Structure containing all parameters needed for the send WQE + */ +__global__ void cu_send_wqe(wqe_params_t params) { + if (threadIdx.x == 0 && blockIdx.x == 0) { + send_wqe(params); + } +} + +/** + * @brief Host function to launch the cu_send_wqe kernel + * + * This function launches the cu_send_wqe kernel with a single thread and + * synchronizes the device to ensure completion. + * + * @param params Structure containing all parameters needed for the send WQE + */ +void launch_send_wqe(wqe_params_t params) { + // Launch kernel + cu_send_wqe<<<1, 1>>>(params); + + // Wait for kernel to complete + cudaDeviceSynchronize(); +} + +//------------------------------------------------------------------------------ +// Completion Queue Element (CQE) Operations +//------------------------------------------------------------------------------ + +/** + * @brief Polls a completion queue for a new completion + * + * This function checks if there is a new completion in the completion queue. + * If a valid completion is found, it updates the byte_cnt parameter with the + * number of bytes transferred and increments the consumer index. 
+ * + * @param byte_cnt Pointer to store the number of bytes transferred (-1 if no + * valid completion) + * @param params Structure containing all parameters needed for polling the CQ + */ +__host__ __device__ void cqe_poll(int32_t* byte_cnt, cqe_poll_params_t params) { + // Calculate the index in the CQ buffer + uint32_t idx = params.consumer_index; + uint32_t buffer_idx = idx & (params.cqe_cnt - 1); + + // Get the CQE at that index + uint8_t* cqe = params.cqe_buf + (buffer_idx * params.cqe_size); + + // The op_own byte is the last byte of the CQE + uint8_t op_own = cqe[params.cqe_size - 1]; + + // Extract the opcode (upper 4 bits) + uint8_t actual_opcode = op_own >> 4; + + // this only checks for valid opcode, in some case should generate error + const uint8_t FIRST_TWO_BITS_MASK = 0x3; // Binary: 00000011 + bool is_valid_opcode = (actual_opcode & ~FIRST_TWO_BITS_MASK) == 0; + + if (is_valid_opcode) { + *byte_cnt = byte_swap32(*(uint32_t*)(cqe + 44)); + + *params.dbrec = byte_swap32((idx + 1) & 0xFFFFFF); + + } else { + *byte_cnt = -1; + } +} + +/** + * @brief CUDA kernel wrapper for cqe_poll function + * + * This kernel launches a single thread to execute the cqe_poll function on the + * GPU. It includes memory fences to ensure proper ordering of memory + * operations. + * + * @param result Pointer to store the result of the poll operation + * @param params Structure containing all parameters needed for polling the CQ + */ +__global__ void cu_cqe_poll(int32_t* result, cqe_poll_params_t params) { + uint32_t i = blockIdx.x * blockDim.x + threadIdx.x; + if (i == 0) { + cqe_poll(result, params); + } + __syncthreads(); + __threadfence_system(); +} + +/** + * @brief Host function to launch the cu_cqe_poll kernel + * + * This function allocates memory for the result, launches the cu_cqe_poll + * kernel, and returns the result of the poll operation. 
+ * + * @param mlx5dv_cq_void Pointer to the mlx5dv_cq structure + * @param consumer_index Current consumer index + * @return CQE_POLL_TRUE if a valid completion was found, CQE_POLL_FALSE + * otherwise, or CQE_POLL_ERROR if an error occurred + */ +cqe_poll_result_t launch_cqe_poll(void* mlx5dv_cq_void, int consumer_index) { + // Cast to proper types on CPU side + struct mlx5dv_cq* cq = (struct mlx5dv_cq*)mlx5dv_cq_void; + + // Allocate memory for result + int32_t* byte_cnt = nullptr; + cudaError_t err = cudaMallocManaged(&byte_cnt, sizeof(int32_t)); + if (err != cudaSuccess) { + return CQE_POLL_ERROR; + } + *byte_cnt = -1; // Initialize to false + + // Create the parameters struct + cqe_poll_params_t params = { + .cqe_buf = (uint8_t*)cq->buf, + .cqe_size = cq->cqe_size, + .consumer_index = (uint32_t)consumer_index, + .cqe_cnt = cq->cqe_cnt, + .dbrec = (uint32_t*)cq->dbrec}; + + // Launch the kernel with the parameters struct + cu_cqe_poll<<<1, 1>>>(byte_cnt, params); + + // Synchronize and get result + cudaDeviceSynchronize(); + + // Check for errors + err = cudaGetLastError(); + if (err != cudaSuccess) { + cudaFree(byte_cnt); + return CQE_POLL_ERROR; + } + + // Get the result + cqe_poll_result_t ret_val = *byte_cnt >= 0 ? CQE_POLL_TRUE : CQE_POLL_FALSE; + cudaFree(byte_cnt); + return ret_val; +} + +/** + * @brief Function to poll send completion queue + * + * This is a wrapper around launch_cqe_poll specifically for send completions. 
+ * + * @param mlx5dv_cq_void Pointer to the mlx5dv_cq structure for the send CQ + * @param consumer_index Current consumer index + * @return CQE_POLL_TRUE if a valid completion was found, CQE_POLL_FALSE + * otherwise, or CQE_POLL_ERROR if an error occurred + */ +cqe_poll_result_t launch_send_cqe_poll( + void* mlx5dv_cq_void, + int consumer_index) { + return launch_cqe_poll(mlx5dv_cq_void, consumer_index); +} + +/** + * @brief Function to poll receive completion queue + * + * This is a wrapper around launch_cqe_poll specifically for receive + * completions. + * + * @param mlx5dv_cq_void Pointer to the mlx5dv_cq structure for the receive CQ + * @param consumer_index Current consumer index + * @return CQE_POLL_TRUE if a valid completion was found, CQE_POLL_FALSE + * otherwise, or CQE_POLL_ERROR if an error occurred + */ +cqe_poll_result_t launch_recv_cqe_poll( + void* mlx5dv_cq_void, + int consumer_index) { + return launch_cqe_poll(mlx5dv_cq_void, consumer_index); +} diff --git a/rdmaxcel-sys/src/rdmaxcel.h b/rdmaxcel-sys/src/rdmaxcel.h new file mode 100644 index 000000000..e3d22ce9a --- /dev/null +++ b/rdmaxcel-sys/src/rdmaxcel.h @@ -0,0 +1,104 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#ifndef RDMAXCEL_H +#define RDMAXCEL_H + +#include +#include +#include + +#ifdef __cplusplus +extern "C" { + +#endif + +typedef enum { + CQE_POLL_ERROR = -1, + CQE_POLL_FALSE = 0, + CQE_POLL_TRUE = 1 +} cqe_poll_result_t; + +// Structure for WQE parameters +typedef struct { + uintptr_t laddr; + uint32_t lkey; + size_t length; + uint64_t wr_id; + bool signaled; + uint32_t op_type; // MLX5_OPCODE_* + uintptr_t raddr; + uint32_t rkey; + uint32_t qp_num; + uint8_t* buf; + uint32_t* dbrec; + uint32_t wqe_cnt; +} wqe_params_t; + +// Structure for CQE poll parameters +typedef struct { + uint8_t* cqe_buf; // CQE buffer address (mlx5dv_cq->buf) + uint32_t cqe_size; // Size of each CQE (mlx5dv_cq->cqe_size) + uint32_t consumer_index; // Current consumer index + uint32_t cqe_cnt; // Total number of CQEs (mlx5dv_cq->cqe_cnt) + uint32_t* dbrec; // Doorbell record (mlx5dv_cq->dbrec) +} cqe_poll_params_t; + +struct ibv_qp* create_qp( + struct ibv_context* context, + struct ibv_pd* pd, + int cq_entries, + int max_send_wr, + int max_recv_wr, + int max_send_sge, + int max_recv_sge); + +struct mlx5dv_qp* create_mlx5dv_qp(struct ibv_qp* qp); + +struct mlx5dv_cq* create_mlx5dv_cq(struct ibv_qp* qp); +struct mlx5dv_cq* create_mlx5dv_send_cq(struct ibv_qp* qp); +struct mlx5dv_cq* create_mlx5dv_recv_cq(struct ibv_qp* qp); + +cudaError_t register_cuda_memory( + struct mlx5dv_qp* dv_qp, + struct mlx5dv_cq* dv_recv_cq, + struct mlx5dv_cq* dv_send_cq); + +// Function that can be called from both host and device code +__host__ __device__ void db_ring(void* dst, void* src); + +__global__ void cu_db_ring(void* dst, void* src); + +// Host function to launch the cu_db_ring kernel +void launch_db_ring(void* dst, void* src); + +cqe_poll_result_t launch_cqe_poll(void* mlx5dv_cq, int32_t cqe_idx); +cqe_poll_result_t launch_send_cqe_poll(void* mlx5dv_cq, int32_t cqe_idx); +cqe_poll_result_t launch_recv_cqe_poll(void* mlx5dv_cq, int32_t cqe_idx); + +__global__ void cu_cqe_poll(int32_t* result, 
cqe_poll_params_t params); + +__host__ __device__ void cqe_poll(int32_t* result, cqe_poll_params_t params); + +// Function that can be called from both host and device code for posting WQEs +__host__ __device__ void send_wqe(wqe_params_t params); +__host__ __device__ void recv_wqe(wqe_params_t params); + +// CUDA kernel that calls send_wqe on the device +__global__ void cu_send_wqe(wqe_params_t params); +__global__ void cu_recv_wqe(wqe_params_t params); + +// Host function to launch the cu_send_wqe kernel +void launch_send_wqe(wqe_params_t params); +void launch_recv_wqe(wqe_params_t params); + +#ifdef __cplusplus +} +#endif + +#endif // RDMAXCEL_H diff --git a/rdmacore-sys/src/wrapper.h b/rdmaxcel-sys/src/test_rdmaxcel.c similarity index 57% rename from rdmacore-sys/src/wrapper.h rename to rdmaxcel-sys/src/test_rdmaxcel.c index e4bd77f03..6f6182e7b 100644 --- a/rdmacore-sys/src/wrapper.h +++ b/rdmaxcel-sys/src/test_rdmaxcel.c @@ -6,5 +6,11 @@ * LICENSE file in the root directory of this source tree. */ -#include -#include +#include +#include "rdmaxcel.h" + +int main() { + void* func_ptr = (void*)&cu_db_ring; + printf("cu_db_ring function address: %p\n", func_ptr); + return 0; +}