@@ -25,7 +25,6 @@ fn create_struct_ty<'ll>(
25
25
}
26
26
}
27
27
28
- //weak_odr hidden local_unnamed_addr addrspace(1) constant i32 0
29
28
pub ( crate ) fn gen_asdf < ' ll > ( cgcx : & CodegenContext < LlvmCodegenBackend > , old_cx : & SimpleCx < ' ll > ) {
30
29
let llcx = unsafe { llvm:: LLVMRustContextCreate ( false ) } ;
31
30
let module_name = CString :: new ( "offload.wrapper.module" ) . unwrap ( ) ;
@@ -236,7 +235,7 @@ pub(crate) fn handle_gpu_code<'ll>(
236
235
cx : & ' ll SimpleCx < ' _ > ,
237
236
) {
238
237
if cx. get_function ( "gen_tgt_offload" ) . is_some ( ) {
239
- let ( offload_entry_ty, at_one, begin, update, end, fn_ty) = gen_globals ( & cx) ;
238
+ let ( offload_entry_ty, at_one, begin, update, end, tgt_bin_desc , fn_ty) = gen_globals ( & cx) ;
240
239
241
240
dbg ! ( "created struct" ) ;
242
241
let mut o_types = vec ! [ ] ;
@@ -249,7 +248,7 @@ pub(crate) fn handle_gpu_code<'ll>(
249
248
}
250
249
}
251
250
dbg ! ( "gen_call_handling" ) ;
252
- gen_call_handling ( & cx, & kernels, at_one, begin, update, end, fn_ty, & o_types) ;
251
+ gen_call_handling ( & cx, & kernels, at_one, begin, update, end, tgt_bin_desc , fn_ty, & o_types) ;
253
252
gen_image_wrapper_module ( & cgcx, & cx) ;
254
253
gen_asdf ( & cgcx, & cx) ;
255
254
} else {
@@ -279,6 +278,7 @@ fn gen_globals<'ll>(
279
278
& ' ll llvm:: Value ,
280
279
& ' ll llvm:: Value ,
281
280
& ' ll llvm:: Type ,
281
+ & ' ll llvm:: Type ,
282
282
) {
283
283
let offload_entry_ty = add_tgt_offload_entry ( & cx) ;
284
284
let kernel_arguments_ty = cx. type_named_struct ( "struct.__tgt_kernel_arguments" ) ;
@@ -312,6 +312,11 @@ fn gen_globals<'ll>(
312
312
let at_one = add_unnamed_global ( & cx, & "" , initializer, PrivateLinkage ) ;
313
313
llvm:: set_alignment ( at_one, Align :: EIGHT ) ;
314
314
315
+ // %struct.__tgt_bin_desc = type { i32, ptr, ptr, ptr }
316
+ let tgt_bin_desc_ty = vec ! [ ti32, tptr, tptr, tptr] ;
317
+ let tgt_bin_desc_name = cx. type_named_struct ( "struct.__tgt_bin_desc" ) ;
318
+ cx. set_struct_body ( tgt_bin_desc_name, & tgt_bin_desc_ty, false ) ;
319
+
315
320
// copied from LLVM
316
321
// typedef struct {
317
322
// uint64_t Reserved;
@@ -379,7 +384,7 @@ fn gen_globals<'ll>(
379
384
attributes:: apply_to_llfn ( bar, Function , & [ nounwind] ) ;
380
385
attributes:: apply_to_llfn ( baz, Function , & [ nounwind] ) ;
381
386
382
- ( offload_entry_ty, at_one, foo, bar, baz, mapper_fn_ty)
387
+ ( offload_entry_ty, at_one, foo, bar, baz, tgt_bin_desc_name , mapper_fn_ty)
383
388
}
384
389
385
390
fn add_priv_unnamed_arr < ' ll > ( cx : & SimpleCx < ' ll > , name : & str , vals : & [ u64 ] ) -> & ' ll llvm:: Value {
@@ -561,6 +566,7 @@ fn gen_call_handling<'ll>(
561
566
begin : & ' ll llvm:: Value ,
562
567
update : & ' ll llvm:: Value ,
563
568
end : & ' ll llvm:: Value ,
569
+ tgt_bin_desc : & ' ll llvm:: Type ,
564
570
fn_ty : & ' ll llvm:: Type ,
565
571
o_types : & [ & ' ll llvm:: Value ] ,
566
572
) {
@@ -586,7 +592,18 @@ fn gen_call_handling<'ll>(
586
592
let mut names: Vec < & llvm:: Value > = Vec :: with_capacity ( num_args as usize ) ;
587
593
588
594
// Step 0)
595
+ // %struct.__tgt_bin_desc = type { i32, ptr, ptr, ptr }
596
+ // %6 = alloca %struct.__tgt_bin_desc, align 8
589
597
unsafe { llvm:: LLVMRustPositionBuilderPastAllocas ( builder. llbuilder , main_fn) } ;
598
+
599
+ let tgt_bin_desc_alloca = builder. my_alloca2 ( tgt_bin_desc, Align :: EIGHT , "EmptyDesc" ) ;
600
+ //fill_byte: &'ll Value,
601
+ //size: &'ll Value,
602
+ //align: Align,
603
+ //flags: MemFlags,
604
+ // call void @llvm.memset.p0.i64(ptr align 8 %EmptyDesc, i8 0, i64 32, i1 false)
605
+ // mem
606
+
590
607
let ty = cx. type_array ( cx. type_ptr ( ) , num_args) ;
591
608
// Baseptr are just the input pointer to the kernel, stored in a local alloca
592
609
let a1 = builder. my_alloca2 ( ty, Align :: EIGHT , ".offload_baseptrs" ) ;
@@ -616,6 +633,46 @@ fn gen_call_handling<'ll>(
616
633
617
634
// Step 1)
618
635
unsafe { llvm:: LLVMRustPositionBefore ( builder. llbuilder , kernel_call) } ;
636
+ builder. memset (
637
+ tgt_bin_desc_alloca,
638
+ cx. get_const_i8 ( 0 ) ,
639
+ cx. get_const_i64 ( 32 ) ,
640
+ Align :: from_bytes ( 8 ) . unwrap ( ) ,
641
+ ) ;
642
+
643
+ let tptr = cx. type_ptr ( ) ;
644
+ let mapper_fn_ty = cx. type_func ( & [ tptr] , cx. type_void ( ) ) ;
645
+ let foo = crate :: declare:: declare_simple_fn (
646
+ & cx,
647
+ & "__tgt_register_lib" ,
648
+ llvm:: CallConv :: CCallConv ,
649
+ llvm:: UnnamedAddr :: No ,
650
+ llvm:: Visibility :: Default ,
651
+ mapper_fn_ty,
652
+ ) ;
653
+ let bar = crate :: declare:: declare_simple_fn (
654
+ & cx,
655
+ & "__tgt_unregister_lib" ,
656
+ llvm:: CallConv :: CCallConv ,
657
+ llvm:: UnnamedAddr :: No ,
658
+ llvm:: Visibility :: Default ,
659
+ mapper_fn_ty,
660
+ ) ;
661
+ let init_ty = cx. type_func ( & [ ] , cx. type_void ( ) ) ;
662
+ let baz = crate :: declare:: declare_simple_fn (
663
+ & cx,
664
+ & "__tgt_init_all_rtls" ,
665
+ llvm:: CallConv :: CCallConv ,
666
+ llvm:: UnnamedAddr :: No ,
667
+ llvm:: Visibility :: Default ,
668
+ init_ty,
669
+ ) ;
670
+
671
+ builder. call ( mapper_fn_ty, foo, & [ tgt_bin_desc_alloca] , None ) ;
672
+ builder. call ( init_ty, baz, & [ ] , None ) ;
673
+
674
+ // call void @__tgt_register_lib(ptr noundef %6)
675
+ // call void @__tgt_init_all_rtls()
619
676
for i in 0 ..num_args {
620
677
let idx = cx. get_const_i32 ( i) ;
621
678
let gep1 = builder. inbounds_gep ( ty, a1, & [ i32_0, idx] ) ;
@@ -667,6 +724,7 @@ fn gen_call_handling<'ll>(
667
724
nullptr,
668
725
] ;
669
726
builder. call ( fn_ty, end, & args, None ) ;
727
+ builder. call ( mapper_fn_ty, bar, & [ tgt_bin_desc_alloca] , None ) ;
670
728
671
729
// call void @__tgt_target_data_begin_mapper(ptr @1, i64 -1, i32 3, ptr %27, ptr %28, ptr %29, ptr @.offload_maptypes, ptr null, ptr null)
672
730
// call void @__tgt_target_data_update_mapper(ptr @1, i64 -1, i32 2, ptr %46, ptr %47, ptr %48, ptr @.offload_maptypes.1, ptr null, ptr null)
0 commit comments