Skip to content

Commit 1c1953d

Browse files
committed
gpu offload memory-transfer mvp
cleanups
1 parent fa72869 commit 1c1953d

File tree

13 files changed

+940
-5
lines changed

13 files changed

+940
-5
lines changed

compiler/rustc_codegen_llvm/src/back/lto.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -653,6 +653,7 @@ pub(crate) fn run_pass_manager(
653653
// We then run the llvm_optimize function a second time, to optimize the code which we generated
654654
// in the enzyme differentiation pass.
655655
let enable_ad = config.autodiff.contains(&config::AutoDiff::Enable);
656+
let enable_gpu = config.offload.contains(&config::Offload::Enable);
656657
let stage = if thin {
657658
write::AutodiffStage::PreAD
658659
} else {
@@ -667,6 +668,13 @@ pub(crate) fn run_pass_manager(
667668
write::llvm_optimize(cgcx, dcx, module, None, config, opt_level, opt_stage, stage)?;
668669
}
669670

671+
// GPU offload handling only runs when the compiler was built with `llvm_enzyme`, offloading was
// requested, and we are not doing thin LTO. A fresh `SimpleCx` over the module is handed to
// `handle_gpu_code`. (A stray `dbg!` that printed `enable_gpu` to stderr was removed here.)
if cfg!(llvm_enzyme) && enable_gpu && !thin {
    let cx =
        SimpleCx::new(module.module_llvm.llmod(), &module.module_llvm.llcx, cgcx.pointer_size);
    crate::builder::gpu_offload::handle_gpu_code(cgcx, &cx);
}
677+
670678
if cfg!(llvm_enzyme) && enable_ad && !thin {
671679
let cx =
672680
SimpleCx::new(module.module_llvm.llmod(), &module.module_llvm.llcx, cgcx.pointer_size);

compiler/rustc_codegen_llvm/src/builder.rs

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ use std::ops::Deref;
33
use std::{iter, ptr};
44

55
pub(crate) mod autodiff;
6+
pub(crate) mod gpu_offload;
67

78
use libc::{c_char, c_uint, size_t};
89
use rustc_abi as abi;
@@ -117,6 +118,70 @@ impl<'a, 'll, CX: Borrow<SCx<'ll>>> GenericBuilder<'a, 'll, CX> {
117118
}
118119
bx
119120
}
121+
122+
/// Builds an `alloca` of `ty` with this builder, sets its alignment to `align`, and returns the
/// result of pointer-casting it to the type given by `self.cx.type_ptr()`.
///
/// If `name` is non-empty it is attached to the returned (cast) value; an empty `name` leaves
/// the value unnamed.
pub(crate) fn my_alloca2(&mut self, ty: &'ll Type, align: Align, name: &str) -> &'ll Value {
    let val = unsafe {
        let alloca = llvm::LLVMBuildAlloca(self.llbuilder, ty, UNNAMED);
        llvm::LLVMSetAlignment(alloca, align.bytes() as c_uint);
        // Cast to default addrspace if necessary
        llvm::LLVMBuildPointerCast(self.llbuilder, alloca, self.cx.type_ptr(), UNNAMED)
    };
    if !name.is_empty() {
        // `set_value_name` receives a byte slice, so the UTF-8 bytes of `name` can be passed
        // directly. Going through `CString` only added an allocation and a panic on interior
        // NUL bytes.
        llvm::set_value_name(val, name.as_bytes());
    }
    val
}
135+
136+
pub(crate) fn inbounds_gep(
137+
&mut self,
138+
ty: &'ll Type,
139+
ptr: &'ll Value,
140+
indices: &[&'ll Value],
141+
) -> &'ll Value {
142+
unsafe {
143+
llvm::LLVMBuildGEPWithNoWrapFlags(
144+
self.llbuilder,
145+
ty,
146+
ptr,
147+
indices.as_ptr(),
148+
indices.len() as c_uint,
149+
UNNAMED,
150+
GEPNoWrapFlags::InBounds,
151+
)
152+
}
153+
}
154+
155+
pub(crate) fn store(&mut self, val: &'ll Value, ptr: &'ll Value, align: Align) -> &'ll Value {
156+
debug!("Store {:?} -> {:?}", val, ptr);
157+
assert_eq!(self.cx.type_kind(self.cx.val_ty(ptr)), TypeKind::Pointer);
158+
unsafe {
159+
let store = llvm::LLVMBuildStore(self.llbuilder, val, ptr);
160+
llvm::LLVMSetAlignment(store, align.bytes() as c_uint);
161+
store
162+
}
163+
}
164+
165+
pub(crate) fn load(&mut self, ty: &'ll Type, ptr: &'ll Value, align: Align) -> &'ll Value {
166+
unsafe {
167+
let load = llvm::LLVMBuildLoad2(self.llbuilder, ty, ptr, UNNAMED);
168+
llvm::LLVMSetAlignment(load, align.bytes() as c_uint);
169+
load
170+
}
171+
}
172+
173+
fn memset(&mut self, ptr: &'ll Value, fill_byte: &'ll Value, size: &'ll Value, align: Align) {
174+
unsafe {
175+
llvm::LLVMRustBuildMemSet(
176+
self.llbuilder,
177+
ptr,
178+
align.bytes() as c_uint,
179+
fill_byte,
180+
size,
181+
false,
182+
);
183+
}
184+
}
120185
}
121186

122187
/// Empty string, to be used where LLVM expects an instruction name, indicating

0 commit comments

Comments
 (0)