@@ -4,17 +4,18 @@ use llvm::Linkage::*;
44use rustc_abi:: Align ;
55use rustc_codegen_ssa:: back:: write:: CodegenContext ;
66use rustc_codegen_ssa:: traits:: BaseTypeCodegenMethods ;
7+ use rustc_middle:: ty:: { self , PseudoCanonicalInput , Ty , TyCtxt , TypingEnv } ;
78
89use crate :: builder:: SBuilder ;
9- use crate :: common:: AsCCharPtr ;
1010use crate :: llvm:: AttributePlace :: Function ;
11- use crate :: llvm:: { self , Linkage , Type , Value } ;
11+ use crate :: llvm:: { self , BasicBlock , Linkage , Type , Value } ;
1212use crate :: { LlvmCodegenBackend , SimpleCx , attributes} ;
1313
1414pub ( crate ) fn handle_gpu_code < ' ll > (
1515 _cgcx : & CodegenContext < LlvmCodegenBackend > ,
16- cx : & ' ll SimpleCx < ' _ > ,
16+ _cx : & ' ll SimpleCx < ' _ > ,
1717) {
18+ /*
1819 // The offload memory transfer type for each kernel
1920 let mut memtransfer_types = vec![];
2021 let mut region_ids = vec![];
@@ -29,6 +30,7 @@ pub(crate) fn handle_gpu_code<'ll>(
2930 }
3031
3132 gen_call_handling(&cx, &memtransfer_types, &region_ids);
33+ */
3234}
3335
3436// ; Function Attrs: nounwind
@@ -76,7 +78,7 @@ fn generate_at_one<'ll>(cx: &'ll SimpleCx<'_>) -> &'ll llvm::Value {
7678 at_one
7779}
7880
79- struct TgtOffloadEntry {
81+ pub ( crate ) struct TgtOffloadEntry {
8082 // uint64_t Reserved;
8183 // uint16_t Version;
8284 // uint16_t Kind;
@@ -253,11 +255,14 @@ pub(crate) fn add_global<'ll>(
253255// This function returns a memtransfer value which encodes how arguments to this kernel shall be
254256// mapped to/from the gpu. It also returns a region_id with the name of this kernel, to be
255257// concatenated into the list of region_ids.
256- fn gen_define_handling < ' ll > (
257- cx : & ' ll SimpleCx < ' _ > ,
258+ pub ( crate ) fn gen_define_handling < ' ll , ' tcx > (
259+ cx : & SimpleCx < ' ll > ,
260+ tcx : TyCtxt < ' tcx > ,
258261 kernel : & ' ll llvm:: Value ,
259262 offload_entry_ty : & ' ll llvm:: Type ,
260- num : i64 ,
263+ // TODO(Sa4dUs): Define a typetree once i have a better idea of what do we exactly need
264+ tt : Vec < Ty < ' tcx > > ,
265+ symbol : & str ,
261266) -> ( & ' ll llvm:: Value , & ' ll llvm:: Value ) {
262267 let types = cx. func_params_types ( cx. get_type_of_global ( kernel) ) ;
263268 // It seems like non-pointer values are automatically mapped. So here, we focus on pointer (or
@@ -267,37 +272,50 @@ fn gen_define_handling<'ll>(
267272 . filter ( |& x| matches ! ( cx. type_kind( x) , rustc_codegen_ssa:: common:: TypeKind :: Pointer ) )
268273 . count ( ) ;
269274
275+ // TODO(Sa4dUs): Add typetrees here
276+ let ptr_sizes = types
277+ . iter ( )
278+ . zip ( tt)
279+ . filter_map ( |( & x, ty) | match cx. type_kind ( x) {
280+ rustc_codegen_ssa:: common:: TypeKind :: Pointer => Some ( get_payload_size ( tcx, ty) ) ,
281+ _ => None ,
282+ } )
283+ . collect :: < Vec < u64 > > ( ) ;
284+
270285 // The sizes are no longer a hardcoded placeholder: `ptr_sizes` (computed above) holds one
271286 // payload size per pointer argument, derived from the Rust types passed in through `tt`.
272287 // That way e.g. `&[f32;256]` will result in 4*256 bytes.
273288 // (Previously every pointer argument used a placeholder value of 1024 bytes.)
274- add_priv_unnamed_arr ( & cx, & format ! ( ".offload_sizes.{num }" ) , & vec ! [ 1024 ; num_ptr_types ] ) ;
289+ add_priv_unnamed_arr ( & cx, & format ! ( ".offload_sizes.{symbol }" ) , & ptr_sizes ) ;
275290 // Here we figure out whether something needs to be copied to the gpu (=1), from the gpu (=2),
276291 // or both to and from the gpu (=3). Other values shouldn't affect us for now.
277292 // A non-mutable reference or pointer will be 1, an array that's not read, but fully overwritten
278293 // will be 2. For now, everything is 3, until we have our frontend set up.
279294 // 1+2+32: 1 (MapTo), 2 (MapFrom), 32 (Add one extra input ptr per function, to be used later).
280295 let memtransfer_types = add_priv_unnamed_arr (
281296 & cx,
282- & format ! ( ".offload_maptypes.{num }" ) ,
297+ & format ! ( ".offload_maptypes.{symbol }" ) ,
283298 & vec ! [ 1 + 2 + 32 ; num_ptr_types] ,
284299 ) ;
300+
285301 // Next: For each function, generate these three entries. A weak constant,
286302 // the llvm.rodata entry name, and the llvm_offload_entries value
287303
288- let name = format ! ( ".kernel_{num }.region_id" ) ;
304+ let name = format ! ( ".{symbol }.region_id" ) ;
289305 let initializer = cx. get_const_i8 ( 0 ) ;
290306 let region_id = add_unnamed_global ( & cx, & name, initializer, WeakAnyLinkage ) ;
291307
292- let c_entry_name = CString :: new ( format ! ( "kernel_{num}" ) ) . unwrap ( ) ;
308+ let c_entry_name = CString :: new ( symbol ) . unwrap ( ) ;
293309 let c_val = c_entry_name. as_bytes_with_nul ( ) ;
294- let offload_entry_name = format ! ( ".offloading.entry_name.{num }" ) ;
310+ let offload_entry_name = format ! ( ".offloading.entry_name.{symbol }" ) ;
295311
296312 let initializer = crate :: common:: bytes_in_context ( cx. llcx , c_val) ;
297313 let llglobal = add_unnamed_global ( & cx, & offload_entry_name, initializer, InternalLinkage ) ;
298314 llvm:: set_alignment ( llglobal, Align :: ONE ) ;
299315 llvm:: set_section ( llglobal, c".llvm.rodata.offloading" ) ;
300- let name = format ! ( ".offloading.entry.kernel_{num}" ) ;
316+
317+ // Not actively used yet, for calling real kernels
318+ let name = format ! ( ".offloading.entry.{symbol}" ) ;
301319
302320 // See the __tgt_offload_entry documentation above.
303321 let elems = TgtOffloadEntry :: new ( & cx, region_id, llglobal) ;
@@ -314,7 +332,57 @@ fn gen_define_handling<'ll>(
314332 ( memtransfer_types, region_id)
315333}
316334
317- pub ( crate ) fn declare_offload_fn < ' ll > (
335+ // TODO(Sa4dUs): move this to a proper place
336+ fn get_payload_size < ' tcx > ( tcx : TyCtxt < ' tcx > , ty : Ty < ' tcx > ) -> u64 {
337+ match ty. kind ( ) {
338+ /*
339+ rustc_middle::infer::canonical::ir::TyKind::Bool => todo!(),
340+ rustc_middle::infer::canonical::ir::TyKind::Char => todo!(),
341+ rustc_middle::infer::canonical::ir::TyKind::Int(int_ty) => todo!(),
342+ rustc_middle::infer::canonical::ir::TyKind::Uint(uint_ty) => todo!(),
343+ rustc_middle::infer::canonical::ir::TyKind::Float(float_ty) => todo!(),
344+ rustc_middle::infer::canonical::ir::TyKind::Adt(_, _) => todo!(),
345+ rustc_middle::infer::canonical::ir::TyKind::Foreign(_) => todo!(),
346+ rustc_middle::infer::canonical::ir::TyKind::Str => todo!(),
347+ rustc_middle::infer::canonical::ir::TyKind::Array(_, _) => todo!(),
348+ rustc_middle::infer::canonical::ir::TyKind::Pat(_, _) => todo!(),
349+ rustc_middle::infer::canonical::ir::TyKind::Slice(_) => todo!(),
350+ rustc_middle::infer::canonical::ir::TyKind::RawPtr(_, mutability) => todo!(),
351+ */
352+ ty:: Ref ( _, inner, _) => get_payload_size ( tcx, * inner) ,
353+ /*
354+ rustc_middle::infer::canonical::ir::TyKind::FnDef(_, _) => todo!(),
355+ rustc_middle::infer::canonical::ir::TyKind::FnPtr(binder, fn_header) => todo!(),
356+ rustc_middle::infer::canonical::ir::TyKind::UnsafeBinder(unsafe_binder_inner) => todo!(),
357+ rustc_middle::infer::canonical::ir::TyKind::Dynamic(_, _) => todo!(),
358+ rustc_middle::infer::canonical::ir::TyKind::Closure(_, _) => todo!(),
359+ rustc_middle::infer::canonical::ir::TyKind::CoroutineClosure(_, _) => todo!(),
360+ rustc_middle::infer::canonical::ir::TyKind::Coroutine(_, _) => todo!(),
361+ rustc_middle::infer::canonical::ir::TyKind::CoroutineWitness(_, _) => todo!(),
362+ rustc_middle::infer::canonical::ir::TyKind::Never => todo!(),
363+ rustc_middle::infer::canonical::ir::TyKind::Tuple(_) => todo!(),
364+ rustc_middle::infer::canonical::ir::TyKind::Alias(alias_ty_kind, alias_ty) => todo!(),
365+ rustc_middle::infer::canonical::ir::TyKind::Param(_) => todo!(),
366+ rustc_middle::infer::canonical::ir::TyKind::Bound(bound_var_index_kind, _) => todo!(),
367+ rustc_middle::infer::canonical::ir::TyKind::Placeholder(_) => todo!(),
368+ rustc_middle::infer::canonical::ir::TyKind::Infer(infer_ty) => todo!(),
369+ rustc_middle::infer::canonical::ir::TyKind::Error(_) => todo!(),
370+ */
371+ _ => {
372+ tcx
373+ // TODO(Sa4dUs): Maybe `.as_query_input()`?
374+ . layout_of ( PseudoCanonicalInput {
375+ typing_env : TypingEnv :: fully_monomorphized ( ) ,
376+ value : ty,
377+ } )
378+ . unwrap ( )
379+ . size
380+ . bytes ( )
381+ }
382+ }
383+ }
384+
385+ fn declare_offload_fn < ' ll > (
318386 cx : & ' ll SimpleCx < ' _ > ,
319387 name : & str ,
320388 ty : & ' ll llvm:: Type ,
@@ -349,10 +417,13 @@ pub(crate) fn declare_offload_fn<'ll>(
349417// 4. set insert point after kernel call.
350418// 5. generate all the GEPS and stores, to be used in 6)
351419// 6. generate __tgt_target_data_end calls to move data from the GPU
352- fn gen_call_handling < ' ll > (
353- cx : & ' ll SimpleCx < ' _ > ,
420+ pub ( crate ) fn gen_call_handling < ' ll > (
421+ cx : & SimpleCx < ' ll > ,
422+ bb : & BasicBlock ,
423+ kernels : & [ & ' ll llvm:: Value ] ,
354424 memtransfer_types : & [ & ' ll llvm:: Value ] ,
355425 region_ids : & [ & ' ll llvm:: Value ] ,
426+ llfn : & ' ll Value ,
356427) {
357428 let ( tgt_decl, tgt_target_kernel_ty) = generate_launcher ( & cx) ;
358429 // %struct.__tgt_bin_desc = type { i32, ptr, ptr, ptr }
@@ -365,27 +436,14 @@ fn gen_call_handling<'ll>(
365436 let tgt_kernel_decl = KernelArgsTy :: new_decl ( & cx) ;
366437 let ( begin_mapper_decl, _, end_mapper_decl, fn_ty) = gen_tgt_data_mappers ( & cx) ;
367438
368- let main_fn = cx. get_function ( "main" ) ;
369- let Some ( main_fn) = main_fn else { return } ;
370- let kernel_name = "kernel_1" ;
371- let call = unsafe {
372- llvm:: LLVMRustGetFunctionCall ( main_fn, kernel_name. as_c_char_ptr ( ) , kernel_name. len ( ) )
373- } ;
374- let Some ( kernel_call) = call else {
375- return ;
376- } ;
377- let kernel_call_bb = unsafe { llvm:: LLVMGetInstructionParent ( kernel_call) } ;
378- let called = unsafe { llvm:: LLVMGetCalledValue ( kernel_call) . unwrap ( ) } ;
379- let mut builder = SBuilder :: build ( cx, kernel_call_bb) ;
380-
381- let types = cx. func_params_types ( cx. get_type_of_global ( called) ) ;
439+ let mut builder = SBuilder :: build ( cx, bb) ;
440+
441+ let types = cx. func_params_types ( cx. get_type_of_global ( kernels[ 0 ] ) ) ;
382442 let num_args = types. len ( ) as u64 ;
383443
384444 // Step 0)
385445 // %struct.__tgt_bin_desc = type { i32, ptr, ptr, ptr }
386446 // %6 = alloca %struct.__tgt_bin_desc, align 8
387- unsafe { llvm:: LLVMRustPositionBuilderPastAllocas ( builder. llbuilder , main_fn) } ;
388-
389447 let tgt_bin_desc_alloca = builder. direct_alloca ( tgt_bin_desc, Align :: EIGHT , "EmptyDesc" ) ;
390448
391449 let ty = cx. type_array ( cx. type_ptr ( ) , num_args) ;
@@ -401,15 +459,14 @@ fn gen_call_handling<'ll>(
401459 let a5 = builder. direct_alloca ( tgt_kernel_decl, Align :: EIGHT , "kernel_args" ) ;
402460
403461 // Step 1)
404- unsafe { llvm:: LLVMRustPositionBefore ( builder. llbuilder , kernel_call) } ;
405462 builder. memset ( tgt_bin_desc_alloca, cx. get_const_i8 ( 0 ) , cx. get_const_i64 ( 32 ) , Align :: EIGHT ) ;
406463
407464 // Now we allocate once per function param, a copy to be passed to one of our maps.
408465 let mut vals = vec ! [ ] ;
409466 let mut geps = vec ! [ ] ;
410467 let i32_0 = cx. get_const_i32 ( 0 ) ;
411- for index in 0 ..types . len ( ) {
412- let v = unsafe { llvm:: LLVMGetOperand ( kernel_call , index as u32 ) . unwrap ( ) } ;
468+ for index in 0 ..num_args {
469+ let v = unsafe { llvm:: LLVMGetParam ( llfn , index as u32 ) } ;
413470 let gep = builder. inbounds_gep ( cx. type_f32 ( ) , v, & [ i32_0] ) ;
414471 vals. push ( v) ;
415472 geps. push ( gep) ;
@@ -501,13 +558,8 @@ fn gen_call_handling<'ll>(
501558 region_ids[ 0 ] ,
502559 a5,
503560 ] ;
504- let offload_success = builder. call ( tgt_target_kernel_ty, tgt_decl, & args, None ) ;
561+ builder. call ( tgt_target_kernel_ty, tgt_decl, & args, None ) ;
505562 // %41 = call i32 @__tgt_target_kernel(ptr @1, i64 -1, i32 2097152, i32 256, ptr @.kernel_1.region_id, ptr %kernel_args)
506- unsafe {
507- let next = llvm:: LLVMGetNextInstruction ( offload_success) . unwrap ( ) ;
508- llvm:: LLVMRustPositionAfter ( builder. llbuilder , next) ;
509- llvm:: LLVMInstructionEraseFromParent ( next) ;
510- }
511563
512564 // Step 4)
513565 let geps = get_geps ( & mut builder, & cx, ty, ty2, a1, a2, a4) ;
@@ -516,8 +568,4 @@ fn gen_call_handling<'ll>(
516568 builder. call ( mapper_fn_ty, unregister_lib_decl, & [ tgt_bin_desc_alloca] , None ) ;
517569
518570 drop ( builder) ;
519- // FIXME(offload) The issue is that we right now add a call to the gpu version of the function,
520- // and then delete the call to the CPU version. In the future, we should use an intrinsic which
521- // directly resolves to a call to the GPU version.
522- unsafe { llvm:: LLVMDeleteFunction ( called) } ;
523571}
0 commit comments