Skip to content

Commit c8d7349

Browse files
committed
it works. jay!
1 parent c478fbe commit c8d7349

File tree

2 files changed

+80
-4
lines changed

2 files changed

+80
-4
lines changed

compiler/rustc_codegen_llvm/src/builder.rs

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,24 @@ impl<'a, 'll, CX: Borrow<SCx<'ll>>> GenericBuilder<'a, 'll, CX> {
177177
}
178178
}
179179

180+
/// Emits a memset by calling `llvm::LLVMRustBuildMemSet` with this builder's
/// `llbuilder`.
///
/// * `ptr` - destination pointer value handed to the memset.
/// * `fill_byte` - value forwarded as the fill argument.
/// * `size` - value forwarded as the length argument.
/// * `align` - destination alignment; forwarded as `align.bytes()` cast to `c_uint`.
///
/// Whatever the FFI call returns is discarded.
fn memset(
181+
&mut self,
182+
ptr: &'ll Value,
183+
fill_byte: &'ll Value,
184+
size: &'ll Value,
185+
align: Align,
186+
) {
187+
// NOTE(review): no SAFETY justification is written here; presumably sound
// because every value shares the builder's `'ll` LLVM context — confirm.
unsafe {
188+
llvm::LLVMRustBuildMemSet(
189+
self.llbuilder,
190+
ptr,
191+
// NOTE(review): the `as` cast would silently truncate an alignment that
// does not fit in `c_uint` — presumably never the case in practice; confirm.
align.bytes() as c_uint,
192+
fill_byte,
193+
size,
194+
// NOTE(review): presumably the "is volatile" flag (it lines up with the
// trailing `i1 false` of `llvm.memset`) — confirm against the FFI declaration.
false,
195+
);
196+
}
197+
}
180198
}
181199

182200
/// Empty string, to be used where LLVM expects an instruction name, indicating

compiler/rustc_codegen_llvm/src/builder/gpu_offload.rs

Lines changed: 62 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@ fn create_struct_ty<'ll>(
2525
}
2626
}
2727

28-
//weak_odr hidden local_unnamed_addr addrspace(1) constant i32 0
2928
pub(crate) fn gen_asdf<'ll>(cgcx: &CodegenContext<LlvmCodegenBackend>, old_cx: &SimpleCx<'ll>) {
3029
let llcx = unsafe { llvm::LLVMRustContextCreate(false) };
3130
let module_name = CString::new("offload.wrapper.module").unwrap();
@@ -236,7 +235,7 @@ pub(crate) fn handle_gpu_code<'ll>(
236235
cx: &'ll SimpleCx<'_>,
237236
) {
238237
if cx.get_function("gen_tgt_offload").is_some() {
239-
let (offload_entry_ty, at_one, begin, update, end, fn_ty) = gen_globals(&cx);
238+
let (offload_entry_ty, at_one, begin, update, end, tgt_bin_desc, fn_ty) = gen_globals(&cx);
240239

241240
dbg!("created struct");
242241
let mut o_types = vec![];
@@ -249,7 +248,7 @@ pub(crate) fn handle_gpu_code<'ll>(
249248
}
250249
}
251250
dbg!("gen_call_handling");
252-
gen_call_handling(&cx, &kernels, at_one, begin, update, end, fn_ty, &o_types);
251+
gen_call_handling(&cx, &kernels, at_one, begin, update, end, tgt_bin_desc, fn_ty, &o_types);
253252
gen_image_wrapper_module(&cgcx, &cx);
254253
gen_asdf(&cgcx, &cx);
255254
} else {
@@ -279,6 +278,7 @@ fn gen_globals<'ll>(
279278
&'ll llvm::Value,
280279
&'ll llvm::Value,
281280
&'ll llvm::Type,
281+
&'ll llvm::Type,
282282
) {
283283
let offload_entry_ty = add_tgt_offload_entry(&cx);
284284
let kernel_arguments_ty = cx.type_named_struct("struct.__tgt_kernel_arguments");
@@ -312,6 +312,11 @@ fn gen_globals<'ll>(
312312
let at_one = add_unnamed_global(&cx, &"", initializer, PrivateLinkage);
313313
llvm::set_alignment(at_one, Align::EIGHT);
314314

315+
// %struct.__tgt_bin_desc = type { i32, ptr, ptr, ptr }
316+
let tgt_bin_desc_ty = vec![ti32, tptr, tptr, tptr];
317+
let tgt_bin_desc_name = cx.type_named_struct("struct.__tgt_bin_desc");
318+
cx.set_struct_body(tgt_bin_desc_name, &tgt_bin_desc_ty, false);
319+
315320
// copied from LLVM
316321
// typedef struct {
317322
// uint64_t Reserved;
@@ -379,7 +384,7 @@ fn gen_globals<'ll>(
379384
attributes::apply_to_llfn(bar, Function, &[nounwind]);
380385
attributes::apply_to_llfn(baz, Function, &[nounwind]);
381386

382-
(offload_entry_ty, at_one, foo, bar, baz, mapper_fn_ty)
387+
(offload_entry_ty, at_one, foo, bar, baz, tgt_bin_desc_name, mapper_fn_ty)
383388
}
384389

385390
fn add_priv_unnamed_arr<'ll>(cx: &SimpleCx<'ll>, name: &str, vals: &[u64]) -> &'ll llvm::Value {
@@ -561,6 +566,7 @@ fn gen_call_handling<'ll>(
561566
begin: &'ll llvm::Value,
562567
update: &'ll llvm::Value,
563568
end: &'ll llvm::Value,
569+
tgt_bin_desc: &'ll llvm::Type,
564570
fn_ty: &'ll llvm::Type,
565571
o_types: &[&'ll llvm::Value],
566572
) {
@@ -586,7 +592,18 @@ fn gen_call_handling<'ll>(
586592
let mut names: Vec<&llvm::Value> = Vec::with_capacity(num_args as usize);
587593

588594
// Step 0)
595+
// %struct.__tgt_bin_desc = type { i32, ptr, ptr, ptr }
596+
// %6 = alloca %struct.__tgt_bin_desc, align 8
589597
unsafe { llvm::LLVMRustPositionBuilderPastAllocas(builder.llbuilder, main_fn) };
598+
599+
let tgt_bin_desc_alloca = builder.my_alloca2(tgt_bin_desc, Align::EIGHT, "EmptyDesc");
600+
//fill_byte: &'ll Value,
601+
//size: &'ll Value,
602+
//align: Align,
603+
//flags: MemFlags,
604+
// call void @llvm.memset.p0.i64(ptr align 8 %EmptyDesc, i8 0, i64 32, i1 false)
605+
// mem
606+
590607
let ty = cx.type_array(cx.type_ptr(), num_args);
591608
// Baseptrs are just the input pointers to the kernel, stored in a local alloca
592609
let a1 = builder.my_alloca2(ty, Align::EIGHT, ".offload_baseptrs");
@@ -616,6 +633,46 @@ fn gen_call_handling<'ll>(
616633

617634
// Step 1)
618635
unsafe { llvm::LLVMRustPositionBefore(builder.llbuilder, kernel_call) };
636+
builder.memset(
637+
tgt_bin_desc_alloca,
638+
cx.get_const_i8(0),
639+
cx.get_const_i64(32),
640+
Align::from_bytes(8).unwrap(),
641+
);
642+
643+
let tptr = cx.type_ptr();
644+
let mapper_fn_ty = cx.type_func(&[tptr], cx.type_void());
645+
let foo = crate::declare::declare_simple_fn(
646+
&cx,
647+
&"__tgt_register_lib",
648+
llvm::CallConv::CCallConv,
649+
llvm::UnnamedAddr::No,
650+
llvm::Visibility::Default,
651+
mapper_fn_ty,
652+
);
653+
let bar = crate::declare::declare_simple_fn(
654+
&cx,
655+
&"__tgt_unregister_lib",
656+
llvm::CallConv::CCallConv,
657+
llvm::UnnamedAddr::No,
658+
llvm::Visibility::Default,
659+
mapper_fn_ty,
660+
);
661+
let init_ty = cx.type_func(&[], cx.type_void());
662+
let baz = crate::declare::declare_simple_fn(
663+
&cx,
664+
&"__tgt_init_all_rtls",
665+
llvm::CallConv::CCallConv,
666+
llvm::UnnamedAddr::No,
667+
llvm::Visibility::Default,
668+
init_ty,
669+
);
670+
671+
builder.call(mapper_fn_ty, foo, &[tgt_bin_desc_alloca], None);
672+
builder.call(init_ty, baz, &[], None);
673+
674+
// call void @__tgt_register_lib(ptr noundef %6)
675+
// call void @__tgt_init_all_rtls()
619676
for i in 0..num_args {
620677
let idx = cx.get_const_i32(i);
621678
let gep1 = builder.inbounds_gep(ty, a1, &[i32_0, idx]);
@@ -667,6 +724,7 @@ fn gen_call_handling<'ll>(
667724
nullptr,
668725
];
669726
builder.call(fn_ty, end, &args, None);
727+
builder.call(mapper_fn_ty, bar, &[tgt_bin_desc_alloca], None);
670728

671729
// call void @__tgt_target_data_begin_mapper(ptr @1, i64 -1, i32 3, ptr %27, ptr %28, ptr %29, ptr @.offload_maptypes, ptr null, ptr null)
672730
// call void @__tgt_target_data_update_mapper(ptr @1, i64 -1, i32 2, ptr %46, ptr %47, ptr %48, ptr @.offload_maptypes.1, ptr null, ptr null)

0 commit comments

Comments
 (0)