Skip to content

Commit f12ac2b

Browse files
committed
first definition of offload intrinsic (dirty code)
1 parent cf8346d commit f12ac2b

File tree

7 files changed

+210
-45
lines changed

7 files changed

+210
-45
lines changed

compiler/rustc_codegen_llvm/src/builder/gpu_offload.rs

Lines changed: 93 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -4,17 +4,18 @@ use llvm::Linkage::*;
44
use rustc_abi::Align;
55
use rustc_codegen_ssa::back::write::CodegenContext;
66
use rustc_codegen_ssa::traits::BaseTypeCodegenMethods;
7+
use rustc_middle::ty::{self, PseudoCanonicalInput, Ty, TyCtxt, TypingEnv};
78

89
use crate::builder::SBuilder;
9-
use crate::common::AsCCharPtr;
1010
use crate::llvm::AttributePlace::Function;
11-
use crate::llvm::{self, Linkage, Type, Value};
11+
use crate::llvm::{self, BasicBlock, Linkage, Type, Value};
1212
use crate::{LlvmCodegenBackend, SimpleCx, attributes};
1313

1414
pub(crate) fn handle_gpu_code<'ll>(
1515
_cgcx: &CodegenContext<LlvmCodegenBackend>,
16-
cx: &'ll SimpleCx<'_>,
16+
_cx: &'ll SimpleCx<'_>,
1717
) {
18+
/*
1819
// The offload memory transfer type for each kernel
1920
let mut memtransfer_types = vec![];
2021
let mut region_ids = vec![];
@@ -29,6 +30,7 @@ pub(crate) fn handle_gpu_code<'ll>(
2930
}
3031
3132
gen_call_handling(&cx, &memtransfer_types, &region_ids);
33+
*/
3234
}
3335

3436
// ; Function Attrs: nounwind
@@ -76,7 +78,7 @@ fn generate_at_one<'ll>(cx: &'ll SimpleCx<'_>) -> &'ll llvm::Value {
7678
at_one
7779
}
7880

79-
struct TgtOffloadEntry {
81+
pub(crate) struct TgtOffloadEntry {
8082
// uint64_t Reserved;
8183
// uint16_t Version;
8284
// uint16_t Kind;
@@ -253,11 +255,14 @@ pub(crate) fn add_global<'ll>(
253255
// This function returns a memtransfer value which encodes how arguments to this kernel shall be
254256
// mapped to/from the gpu. It also returns a region_id with the name of this kernel, to be
255257
// concatenated into the list of region_ids.
256-
fn gen_define_handling<'ll>(
257-
cx: &'ll SimpleCx<'_>,
258+
pub(crate) fn gen_define_handling<'ll, 'tcx>(
259+
cx: &SimpleCx<'ll>,
260+
tcx: TyCtxt<'tcx>,
258261
kernel: &'ll llvm::Value,
259262
offload_entry_ty: &'ll llvm::Type,
260-
num: i64,
263+
// TODO(Sa4dUs): Define a typetree once I have a better idea of what exactly we need
264+
tt: Vec<Ty<'tcx>>,
265+
symbol: &str,
261266
) -> (&'ll llvm::Value, &'ll llvm::Value) {
262267
let types = cx.func_params_types(cx.get_type_of_global(kernel));
263268
// It seems like non-pointer values are automatically mapped. So here, we focus on pointer (or
@@ -267,37 +272,50 @@ fn gen_define_handling<'ll>(
267272
.filter(|&x| matches!(cx.type_kind(x), rustc_codegen_ssa::common::TypeKind::Pointer))
268273
.count();
269274

275+
// TODO(Sa4dUs): Add typetrees here
276+
let ptr_sizes = types
277+
.iter()
278+
.zip(tt)
279+
.filter_map(|(&x, ty)| match cx.type_kind(x) {
280+
rustc_codegen_ssa::common::TypeKind::Pointer => Some(get_payload_size(tcx, ty)),
281+
_ => None,
282+
})
283+
.collect::<Vec<u64>>();
284+
270285
// We do not know their size anymore at this level, so hardcode a placeholder.
271286
// A follow-up pr will track these from the frontend, where we still have Rust types.
272287
// Then, we will be able to figure out that e.g. `&[f32;256]` will result in 4*256 bytes.
273288
// I decided that 1024 bytes is a great placeholder value for now.
274-
add_priv_unnamed_arr(&cx, &format!(".offload_sizes.{num}"), &vec![1024; num_ptr_types]);
289+
add_priv_unnamed_arr(&cx, &format!(".offload_sizes.{symbol}"), &ptr_sizes);
275290
// Here we figure out whether something needs to be copied to the gpu (=1), from the gpu (=2),
276291
// or both to and from the gpu (=3). Other values shouldn't affect us for now.
277292
// A non-mutable reference or pointer will be 1, an array that's not read, but fully overwritten
278293
// will be 2. For now, everything is 3, until we have our frontend set up.
279294
// 1+2+32: 1 (MapTo), 2 (MapFrom), 32 (Add one extra input ptr per function, to be used later).
280295
let memtransfer_types = add_priv_unnamed_arr(
281296
&cx,
282-
&format!(".offload_maptypes.{num}"),
297+
&format!(".offload_maptypes.{symbol}"),
283298
&vec![1 + 2 + 32; num_ptr_types],
284299
);
300+
285301
// Next: For each function, generate these three entries. A weak constant,
286302
// the llvm.rodata entry name, and the llvm_offload_entries value
287303

288-
let name = format!(".kernel_{num}.region_id");
304+
let name = format!(".{symbol}.region_id");
289305
let initializer = cx.get_const_i8(0);
290306
let region_id = add_unnamed_global(&cx, &name, initializer, WeakAnyLinkage);
291307

292-
let c_entry_name = CString::new(format!("kernel_{num}")).unwrap();
308+
let c_entry_name = CString::new(symbol).unwrap();
293309
let c_val = c_entry_name.as_bytes_with_nul();
294-
let offload_entry_name = format!(".offloading.entry_name.{num}");
310+
let offload_entry_name = format!(".offloading.entry_name.{symbol}");
295311

296312
let initializer = crate::common::bytes_in_context(cx.llcx, c_val);
297313
let llglobal = add_unnamed_global(&cx, &offload_entry_name, initializer, InternalLinkage);
298314
llvm::set_alignment(llglobal, Align::ONE);
299315
llvm::set_section(llglobal, c".llvm.rodata.offloading");
300-
let name = format!(".offloading.entry.kernel_{num}");
316+
317+
// Not actively used yet, for calling real kernels
318+
let name = format!(".offloading.entry.{symbol}");
301319

302320
// See the __tgt_offload_entry documentation above.
303321
let elems = TgtOffloadEntry::new(&cx, region_id, llglobal);
@@ -314,7 +332,57 @@ fn gen_define_handling<'ll>(
314332
(memtransfer_types, region_id)
315333
}
316334

317-
pub(crate) fn declare_offload_fn<'ll>(
335+
// TODO(Sa4dUs): move this to a proper place
336+
fn get_payload_size<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> u64 {
337+
match ty.kind() {
338+
/*
339+
rustc_middle::infer::canonical::ir::TyKind::Bool => todo!(),
340+
rustc_middle::infer::canonical::ir::TyKind::Char => todo!(),
341+
rustc_middle::infer::canonical::ir::TyKind::Int(int_ty) => todo!(),
342+
rustc_middle::infer::canonical::ir::TyKind::Uint(uint_ty) => todo!(),
343+
rustc_middle::infer::canonical::ir::TyKind::Float(float_ty) => todo!(),
344+
rustc_middle::infer::canonical::ir::TyKind::Adt(_, _) => todo!(),
345+
rustc_middle::infer::canonical::ir::TyKind::Foreign(_) => todo!(),
346+
rustc_middle::infer::canonical::ir::TyKind::Str => todo!(),
347+
rustc_middle::infer::canonical::ir::TyKind::Array(_, _) => todo!(),
348+
rustc_middle::infer::canonical::ir::TyKind::Pat(_, _) => todo!(),
349+
rustc_middle::infer::canonical::ir::TyKind::Slice(_) => todo!(),
350+
rustc_middle::infer::canonical::ir::TyKind::RawPtr(_, mutability) => todo!(),
351+
*/
352+
ty::Ref(_, inner, _) => get_payload_size(tcx, *inner),
353+
/*
354+
rustc_middle::infer::canonical::ir::TyKind::FnDef(_, _) => todo!(),
355+
rustc_middle::infer::canonical::ir::TyKind::FnPtr(binder, fn_header) => todo!(),
356+
rustc_middle::infer::canonical::ir::TyKind::UnsafeBinder(unsafe_binder_inner) => todo!(),
357+
rustc_middle::infer::canonical::ir::TyKind::Dynamic(_, _) => todo!(),
358+
rustc_middle::infer::canonical::ir::TyKind::Closure(_, _) => todo!(),
359+
rustc_middle::infer::canonical::ir::TyKind::CoroutineClosure(_, _) => todo!(),
360+
rustc_middle::infer::canonical::ir::TyKind::Coroutine(_, _) => todo!(),
361+
rustc_middle::infer::canonical::ir::TyKind::CoroutineWitness(_, _) => todo!(),
362+
rustc_middle::infer::canonical::ir::TyKind::Never => todo!(),
363+
rustc_middle::infer::canonical::ir::TyKind::Tuple(_) => todo!(),
364+
rustc_middle::infer::canonical::ir::TyKind::Alias(alias_ty_kind, alias_ty) => todo!(),
365+
rustc_middle::infer::canonical::ir::TyKind::Param(_) => todo!(),
366+
rustc_middle::infer::canonical::ir::TyKind::Bound(bound_var_index_kind, _) => todo!(),
367+
rustc_middle::infer::canonical::ir::TyKind::Placeholder(_) => todo!(),
368+
rustc_middle::infer::canonical::ir::TyKind::Infer(infer_ty) => todo!(),
369+
rustc_middle::infer::canonical::ir::TyKind::Error(_) => todo!(),
370+
*/
371+
_ => {
372+
tcx
373+
// TODO(Sa4dUs): Maybe `.as_query_input()`?
374+
.layout_of(PseudoCanonicalInput {
375+
typing_env: TypingEnv::fully_monomorphized(),
376+
value: ty,
377+
})
378+
.unwrap()
379+
.size
380+
.bytes()
381+
}
382+
}
383+
}
384+
385+
fn declare_offload_fn<'ll>(
318386
cx: &'ll SimpleCx<'_>,
319387
name: &str,
320388
ty: &'ll llvm::Type,
@@ -349,10 +417,13 @@ pub(crate) fn declare_offload_fn<'ll>(
349417
// 4. set insert point after kernel call.
350418
// 5. generate all the GEPS and stores, to be used in 6)
351419
// 6. generate __tgt_target_data_end calls to move data from the GPU
352-
fn gen_call_handling<'ll>(
353-
cx: &'ll SimpleCx<'_>,
420+
pub(crate) fn gen_call_handling<'ll>(
421+
cx: &SimpleCx<'ll>,
422+
bb: &BasicBlock,
423+
kernels: &[&'ll llvm::Value],
354424
memtransfer_types: &[&'ll llvm::Value],
355425
region_ids: &[&'ll llvm::Value],
426+
llfn: &'ll Value,
356427
) {
357428
let (tgt_decl, tgt_target_kernel_ty) = generate_launcher(&cx);
358429
// %struct.__tgt_bin_desc = type { i32, ptr, ptr, ptr }
@@ -365,27 +436,14 @@ fn gen_call_handling<'ll>(
365436
let tgt_kernel_decl = KernelArgsTy::new_decl(&cx);
366437
let (begin_mapper_decl, _, end_mapper_decl, fn_ty) = gen_tgt_data_mappers(&cx);
367438

368-
let main_fn = cx.get_function("main");
369-
let Some(main_fn) = main_fn else { return };
370-
let kernel_name = "kernel_1";
371-
let call = unsafe {
372-
llvm::LLVMRustGetFunctionCall(main_fn, kernel_name.as_c_char_ptr(), kernel_name.len())
373-
};
374-
let Some(kernel_call) = call else {
375-
return;
376-
};
377-
let kernel_call_bb = unsafe { llvm::LLVMGetInstructionParent(kernel_call) };
378-
let called = unsafe { llvm::LLVMGetCalledValue(kernel_call).unwrap() };
379-
let mut builder = SBuilder::build(cx, kernel_call_bb);
380-
381-
let types = cx.func_params_types(cx.get_type_of_global(called));
439+
let mut builder = SBuilder::build(cx, bb);
440+
441+
let types = cx.func_params_types(cx.get_type_of_global(kernels[0]));
382442
let num_args = types.len() as u64;
383443

384444
// Step 0)
385445
// %struct.__tgt_bin_desc = type { i32, ptr, ptr, ptr }
386446
// %6 = alloca %struct.__tgt_bin_desc, align 8
387-
unsafe { llvm::LLVMRustPositionBuilderPastAllocas(builder.llbuilder, main_fn) };
388-
389447
let tgt_bin_desc_alloca = builder.direct_alloca(tgt_bin_desc, Align::EIGHT, "EmptyDesc");
390448

391449
let ty = cx.type_array(cx.type_ptr(), num_args);
@@ -401,15 +459,14 @@ fn gen_call_handling<'ll>(
401459
let a5 = builder.direct_alloca(tgt_kernel_decl, Align::EIGHT, "kernel_args");
402460

403461
// Step 1)
404-
unsafe { llvm::LLVMRustPositionBefore(builder.llbuilder, kernel_call) };
405462
builder.memset(tgt_bin_desc_alloca, cx.get_const_i8(0), cx.get_const_i64(32), Align::EIGHT);
406463

407464
// Now we allocate once per function param, a copy to be passed to one of our maps.
408465
let mut vals = vec![];
409466
let mut geps = vec![];
410467
let i32_0 = cx.get_const_i32(0);
411-
for index in 0..types.len() {
412-
let v = unsafe { llvm::LLVMGetOperand(kernel_call, index as u32).unwrap() };
468+
for index in 0..num_args {
469+
let v = unsafe { llvm::LLVMGetParam(llfn, index as u32) };
413470
let gep = builder.inbounds_gep(cx.type_f32(), v, &[i32_0]);
414471
vals.push(v);
415472
geps.push(gep);
@@ -501,13 +558,8 @@ fn gen_call_handling<'ll>(
501558
region_ids[0],
502559
a5,
503560
];
504-
let offload_success = builder.call(tgt_target_kernel_ty, tgt_decl, &args, None);
561+
builder.call(tgt_target_kernel_ty, tgt_decl, &args, None);
505562
// %41 = call i32 @__tgt_target_kernel(ptr @1, i64 -1, i32 2097152, i32 256, ptr @.kernel_1.region_id, ptr %kernel_args)
506-
unsafe {
507-
let next = llvm::LLVMGetNextInstruction(offload_success).unwrap();
508-
llvm::LLVMRustPositionAfter(builder.llbuilder, next);
509-
llvm::LLVMInstructionEraseFromParent(next);
510-
}
511563

512564
// Step 4)
513565
let geps = get_geps(&mut builder, &cx, ty, ty2, a1, a2, a4);
@@ -516,8 +568,4 @@ fn gen_call_handling<'ll>(
516568
builder.call(mapper_fn_ty, unregister_lib_decl, &[tgt_bin_desc_alloca], None);
517569

518570
drop(builder);
519-
// FIXME(offload) The issue is that we right now add a call to the gpu version of the function,
520-
// and then delete the call to the CPU version. In the future, we should use an intrinsic which
521-
// directly resolves to a call to the GPU version.
522-
unsafe { llvm::LLVMDeleteFunction(called) };
523571
}

compiler/rustc_codegen_llvm/src/intrinsic.rs

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ use tracing::debug;
2323
use crate::abi::FnAbiLlvmExt;
2424
use crate::builder::Builder;
2525
use crate::builder::autodiff::{adjust_activity_to_abi, generate_enzyme_call};
26+
use crate::builder::gpu_offload::TgtOffloadEntry;
2627
use crate::context::CodegenCx;
2728
use crate::errors::AutoDiffWithoutEnable;
2829
use crate::llvm::{self, Metadata, Type, Value};
@@ -195,6 +196,10 @@ impl<'ll, 'tcx> IntrinsicCallBuilderMethods<'tcx> for Builder<'_, 'll, 'tcx> {
195196
codegen_autodiff(self, tcx, instance, args, result);
196197
return Ok(());
197198
}
199+
sym::offload => {
200+
codegen_offload(self, tcx, instance, args, result);
201+
return Ok(());
202+
}
198203
sym::is_val_statically_known => {
199204
if let OperandValue::Immediate(imm) = args[0].val {
200205
self.call_intrinsic(
@@ -1227,6 +1232,72 @@ fn codegen_autodiff<'ll, 'tcx>(
12271232
);
12281233
}
12291234

1235+
fn codegen_offload<'ll, 'tcx>(
1236+
bx: &mut Builder<'_, 'll, 'tcx>,
1237+
tcx: TyCtxt<'tcx>,
1238+
instance: ty::Instance<'tcx>,
1239+
_args: &[OperandRef<'tcx, &'ll Value>],
1240+
_result: PlaceRef<'tcx, &'ll Value>,
1241+
) {
1242+
let cx = bx.cx;
1243+
let fn_args = instance.args;
1244+
1245+
let (target_id, target_args) = match fn_args.into_type_list(tcx)[0].kind() {
1246+
ty::FnDef(def_id, params) => (def_id, params),
1247+
_ => bug!("invalid offload intrinsic arg"),
1248+
};
1249+
1250+
let fn_target = match Instance::try_resolve(tcx, cx.typing_env(), *target_id, target_args) {
1251+
Ok(Some(instance)) => instance,
1252+
Ok(None) => bug!(
1253+
"could not resolve ({:?}, {:?}) to a specific offload instance",
1254+
target_id,
1255+
target_args
1256+
),
1257+
Err(_) => {
1258+
// An error has already been emitted
1259+
return;
1260+
}
1261+
};
1262+
1263+
// TODO(Sa4dUs): Will need typetrees
1264+
let target_symbol = symbol_name_for_instance_in_crate(tcx, fn_target.clone(), LOCAL_CRATE);
1265+
let Some(kernel) = cx.get_function(&target_symbol) else {
1266+
bug!("could not find target function")
1267+
};
1268+
1269+
let offload_entry_ty = TgtOffloadEntry::new_decl(&cx);
1270+
1271+
// Build TypeTree (or something similar)
1272+
let sig = tcx.fn_sig(fn_target.def_id()).skip_binder().skip_binder();
1273+
let inputs = sig.inputs();
1274+
1275+
// TODO(Sa4dUs): separate globals from call-independent headers and use typetrees to reserve the correct amount of memory
1276+
let (memtransfer_type, region_id) = crate::builder::gpu_offload::gen_define_handling(
1277+
cx,
1278+
tcx,
1279+
kernel,
1280+
offload_entry_ty,
1281+
inputs.to_vec(),
1282+
&target_symbol,
1283+
);
1284+
1285+
let kernels = &[kernel];
1286+
1287+
let llfn = bx.llfn();
1288+
1289+
// TODO(Sa4dUs): this is a workaround that postpones fixing the lifetime issue
1290+
let bb = unsafe { llvm::LLVMGetInsertBlock(bx.llbuilder) };
1291+
crate::builder::gpu_offload::gen_call_handling(
1292+
cx,
1293+
bb,
1294+
kernels,
1295+
&[memtransfer_type],
1296+
&[region_id],
1297+
llfn,
1298+
);
1299+
}
1300+
12301301
fn get_args_from_tuple<'ll, 'tcx>(
12311302
bx: &mut Builder<'_, 'll, 'tcx>,
12321303
tuple_op: OperandRef<'tcx, &'ll Value>,

compiler/rustc_codegen_llvm/src/lib.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
//!
55
//! This API is completely unstable and subject to change.
66
7+
// TODO(Sa4dUs): remove this once the implementation is cleaned up; it is only here to silence warnings about unused LLVM wrappers
8+
#![allow(unused)]
79
// tidy-alphabetical-start
810
#![allow(internal_features)]
911
#![doc(html_root_url = "https://doc.rust-lang.org/nightly/nightly-rustc/")]

compiler/rustc_hir_analysis/src/check/intrinsic.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,7 @@ fn intrinsic_operation_unsafety(tcx: TyCtxt<'_>, intrinsic_id: LocalDefId) -> hi
163163
| sym::minnumf128
164164
| sym::mul_with_overflow
165165
| sym::needs_drop
166+
| sym::offload
166167
| sym::powf16
167168
| sym::powf32
168169
| sym::powf64
@@ -310,6 +311,7 @@ pub(crate) fn check_intrinsic_type(
310311
let type_id = tcx.type_of(tcx.lang_items().type_id().unwrap()).instantiate_identity();
311312
(0, 0, vec![type_id, type_id], tcx.types.bool)
312313
}
314+
sym::offload => (2, 0, vec![param(0)], param(1)),
313315
sym::offset => (2, 0, vec![param(0), param(1)], param(0)),
314316
sym::arith_offset => (
315317
1,

compiler/rustc_span/src/symbol.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1558,6 +1558,7 @@ symbols! {
15581558
object_safe_for_dispatch,
15591559
of,
15601560
off,
1561+
offload,
15611562
offset,
15621563
offset_of,
15631564
offset_of_enum,

library/core/src/intrinsics/mod.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3262,6 +3262,10 @@ pub const fn copysignf128(x: f128, y: f128) -> f128;
32623262
#[rustc_intrinsic]
32633263
pub const fn autodiff<F, G, T: crate::marker::Tuple, R>(f: F, df: G, args: T) -> R;
32643264

3265+
#[rustc_nounwind]
3266+
#[rustc_intrinsic]
3267+
pub const fn offload<F, R>(f: F) -> R;
3268+
32653269
/// Inform Miri that a given pointer definitely has a certain alignment.
32663270
#[cfg(miri)]
32673271
#[rustc_allow_const_fn_unstable(const_eval_select)]

0 commit comments

Comments
 (0)