Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 53 additions & 18 deletions compiler/rustc_codegen_llvm/src/builder/gpu_offload.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
use std::ffi::CString;

use bitflags::Flags;
use llvm::Linkage::*;
use rustc_abi::Align;
use rustc_codegen_ssa::common::TypeKind;
use rustc_codegen_ssa::mir::operand::{OperandRef, OperandValue};
use rustc_codegen_ssa::traits::{BaseTypeCodegenMethods, BuilderMethods};
use rustc_middle::bug;
use rustc_middle::ty::offload_meta::OffloadMetadata;
use rustc_middle::ty::offload_meta::{MappingFlags, OffloadMetadata};

use crate::builder::Builder;
use crate::common::CodegenCx;
Expand Down Expand Up @@ -292,7 +293,9 @@ impl KernelArgsTy {
#[derive(Copy, Clone)]
pub(crate) struct OffloadKernelGlobals<'ll> {
pub offload_sizes: &'ll llvm::Value,
pub memtransfer_types: &'ll llvm::Value,
pub memtransfer_begin: &'ll llvm::Value,
pub memtransfer_kernel: &'ll llvm::Value,
pub memtransfer_end: &'ll llvm::Value,
pub region_id: &'ll llvm::Value,
pub offload_entry: &'ll llvm::Value,
}
Expand Down Expand Up @@ -371,18 +374,38 @@ pub(crate) fn gen_define_handling<'ll>(

let offload_entry_ty = offload_globals.offload_entry_ty;

// FIXME(Sa4dUs): add `OMP_MAP_TARGET_PARAM = 0x20` only if necessary
let (sizes, transfer): (Vec<_>, Vec<_>) =
metadata.iter().map(|m| (m.payload_size, m.mode.bits() | 0x20)).unzip();
metadata.iter().map(|m| (m.payload_size, m.mode)).unzip();
// Our begin mapper should only see simplified information about which args have to be
// transferred to the device, the end mapper only about which args should be transferred back.
// Any information beyond that makes it harder for LLVM's opt pass to evaluate whether it can
// safely move (=optimize) the LLVM-IR location of this data transfer. Only the mapping types
// mentioned below are handled, so make sure that we don't generate any other ones.
let handled_mappings = MappingFlags::TO
| MappingFlags::FROM
| MappingFlags::TARGET_PARAM
| MappingFlags::LITERAL
| MappingFlags::IMPLICIT;
for arg in &transfer {
debug_assert!(!arg.contains_unknown_bits());
debug_assert!(arg.difference(handled_mappings).is_empty());
}

let valid_begin_mappings = MappingFlags::TO | MappingFlags::LITERAL | MappingFlags::IMPLICIT;
let transfer_to: Vec<u64> =
transfer.clone().iter().map(|m| m.intersection(valid_begin_mappings).bits()).collect();
let transfer_from: Vec<u64> =
transfer.iter().map(|m| m.intersection(MappingFlags::FROM).bits()).collect();
// FIXME(offload): add `OMP_MAP_TARGET_PARAM = 0x20` only if necessary
let transfer_kernel = vec![MappingFlags::TARGET_PARAM.bits(); transfer_to.len()];

let offload_sizes = add_priv_unnamed_arr(&cx, &format!(".offload_sizes.{symbol}"), &sizes);
// Here we figure out whether something needs to be copied to the gpu (=1), from the gpu (=2),
// or both to and from the gpu (=3). Other values shouldn't affect us for now.
// A non-mutable reference or pointer will be 1, an array that's not read, but fully overwritten
// will be 2. For now, everything is 3, until we have our frontend set up.
// 1+2+32: 1 (MapTo), 2 (MapFrom), 32 (Add one extra input ptr per function, to be used later).
let memtransfer_types =
add_priv_unnamed_arr(&cx, &format!(".offload_maptypes.{symbol}"), &transfer);
let memtransfer_begin =
add_priv_unnamed_arr(&cx, &format!(".offload_maptypes.{symbol}.begin"), &transfer_to);
let memtransfer_kernel =
add_priv_unnamed_arr(&cx, &format!(".offload_maptypes.{symbol}.kernel"), &transfer_kernel);
let memtransfer_end =
add_priv_unnamed_arr(&cx, &format!(".offload_maptypes.{symbol}.end"), &transfer_from);

// Next: For each function, generate these three entries. A weak constant,
// the llvm.rodata entry name, and the llvm_offload_entries value
Expand Down Expand Up @@ -415,8 +438,14 @@ pub(crate) fn gen_define_handling<'ll>(
let c_section_name = CString::new("llvm_offload_entries").unwrap();
llvm::set_section(offload_entry, &c_section_name);

let result =
OffloadKernelGlobals { offload_sizes, memtransfer_types, region_id, offload_entry };
let result = OffloadKernelGlobals {
offload_sizes,
memtransfer_begin,
memtransfer_kernel,
memtransfer_end,
region_id,
offload_entry,
};

cx.offload_kernel_cache.borrow_mut().insert(symbol, result);

Expand Down Expand Up @@ -479,8 +508,14 @@ pub(crate) fn gen_call_handling<'ll, 'tcx>(
offload_dims: &OffloadKernelDims<'ll>,
) {
let cx = builder.cx;
let OffloadKernelGlobals { offload_sizes, offload_entry, memtransfer_types, region_id } =
offload_data;
let OffloadKernelGlobals {
offload_sizes,
offload_entry,
memtransfer_begin,
memtransfer_kernel,
memtransfer_end,
region_id,
} = offload_data;
let OffloadKernelDims { num_workgroups, threads_per_block, workgroup_dims, thread_dims } =
offload_dims;

Expand Down Expand Up @@ -640,14 +675,14 @@ pub(crate) fn gen_call_handling<'ll, 'tcx>(
generate_mapper_call(
builder,
geps,
memtransfer_types,
memtransfer_begin,
begin_mapper_decl,
fn_ty,
num_args,
s_ident_t,
);
let values =
KernelArgsTy::new(&cx, num_args, memtransfer_types, geps, workgroup_dims, thread_dims);
KernelArgsTy::new(&cx, num_args, memtransfer_kernel, geps, workgroup_dims, thread_dims);

// Step 3)
// Here we fill the KernelArgsTy, see the documentation above
Expand All @@ -673,7 +708,7 @@ pub(crate) fn gen_call_handling<'ll, 'tcx>(
generate_mapper_call(
builder,
geps,
memtransfer_types,
memtransfer_end,
end_mapper_decl,
fn_ty,
num_args,
Expand Down
4 changes: 2 additions & 2 deletions tests/codegen-llvm/gpu_offload/control_flow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@
// CHECK: br label %bb3
// CHECK-NOT: define
// CHECK: bb3
// CHECK: call void @__tgt_target_data_begin_mapper(ptr nonnull @anon.{{.*}}.1, i64 -1, i32 1, ptr nonnull %.offload_baseptrs, ptr nonnull %.offload_ptrs, ptr nonnull %.offload_sizes, ptr nonnull @.offload_maptypes.foo, ptr null, ptr null)
// CHECK: call void @__tgt_target_data_begin_mapper(ptr nonnull @anon.{{.*}}.1, i64 -1, i32 1, ptr nonnull %.offload_baseptrs, ptr nonnull %.offload_ptrs, ptr nonnull %.offload_sizes, ptr nonnull @.offload_maptypes.foo.begin, ptr null, ptr null)
// CHECK: %10 = call i32 @__tgt_target_kernel(ptr nonnull @anon.{{.*}}.1, i64 -1, i32 256, i32 32, ptr nonnull @.foo.region_id, ptr nonnull %kernel_args)
// CHECK-NEXT: call void @__tgt_target_data_end_mapper(ptr nonnull @anon.{{.*}}.1, i64 -1, i32 1, ptr nonnull %.offload_baseptrs, ptr nonnull %.offload_ptrs, ptr nonnull %.offload_sizes, ptr nonnull @.offload_maptypes.foo, ptr null, ptr null)
// CHECK-NEXT: call void @__tgt_target_data_end_mapper(ptr nonnull @anon.{{.*}}.1, i64 -1, i32 1, ptr nonnull %.offload_baseptrs, ptr nonnull %.offload_ptrs, ptr nonnull %.offload_sizes, ptr nonnull @.offload_maptypes.foo.end, ptr null, ptr null)
#[unsafe(no_mangle)]
unsafe fn main() {
let A = [1.0, 1.0, 1.0, 1.0, 1.0, 1.0];
Expand Down
89 changes: 50 additions & 39 deletions tests/codegen-llvm/gpu_offload/gpu_host.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,21 +13,23 @@
#[unsafe(no_mangle)]
fn main() {
let mut x = [3.0; 256];
kernel_1(&mut x);
let y = [1.0; 256];
kernel_1(&mut x, &y);
core::hint::black_box(&x);
core::hint::black_box(&y);
}

#[unsafe(no_mangle)]
#[inline(never)]
pub fn kernel_1(x: &mut [f32; 256]) {
core::intrinsics::offload(_kernel_1, [256, 1, 1], [32, 1, 1], (x,))
pub fn kernel_1(x: &mut [f32; 256], y: &[f32; 256]) {
core::intrinsics::offload(_kernel_1, [256, 1, 1], [32, 1, 1], (x, y))
}

#[unsafe(no_mangle)]
#[inline(never)]
pub fn _kernel_1(x: &mut [f32; 256]) {
pub fn _kernel_1(x: &mut [f32; 256], y: &[f32; 256]) {
for i in 0..256 {
x[i] = 21.0;
x[i] = 21.0 + y[i];
}
}

Expand All @@ -39,8 +41,10 @@ pub fn _kernel_1(x: &mut [f32; 256]) {
// CHECK: @anon.{{.*}}.0 = private unnamed_addr constant [23 x i8] c";unknown;unknown;0;0;;\00", align 1
// CHECK: @anon.{{.*}}.1 = private unnamed_addr constant %struct.ident_t { i32 0, i32 2, i32 0, i32 22, ptr @anon.{{.*}}.0 }, align 8

// CHECK: @.offload_sizes._kernel_1 = private unnamed_addr constant [1 x i64] [i64 1024]
// CHECK: @.offload_maptypes._kernel_1 = private unnamed_addr constant [1 x i64] [i64 35]
// CHECK: @.offload_sizes._kernel_1 = private unnamed_addr constant [2 x i64] [i64 1024, i64 1024]
// CHECK: @.offload_maptypes._kernel_1.begin = private unnamed_addr constant [2 x i64] [i64 1, i64 1]
// CHECK: @.offload_maptypes._kernel_1.kernel = private unnamed_addr constant [2 x i64] [i64 32, i64 32]
// CHECK: @.offload_maptypes._kernel_1.end = private unnamed_addr constant [2 x i64] [i64 2, i64 0]
// CHECK: @._kernel_1.region_id = internal constant i8 0
// CHECK: @.offloading.entry_name._kernel_1 = internal unnamed_addr constant [10 x i8] c"_kernel_1\00", section ".llvm.rodata.offloading", align 1
// CHECK: @.offloading.entry._kernel_1 = internal constant %struct.__tgt_offload_entry { i64 0, i16 1, i16 1, i32 0, ptr @._kernel_1.region_id, ptr @.offloading.entry_name._kernel_1, i64 0, i64 0, ptr null }, section "llvm_offload_entries", align 8
Expand All @@ -52,22 +56,23 @@ pub fn _kernel_1(x: &mut [f32; 256]) {
// CHECK: define{{( dso_local)?}} void @main()
// CHECK-NEXT: start:
// CHECK-NEXT: %0 = alloca [8 x i8], align 8
// CHECK-NEXT: %1 = alloca [8 x i8], align 8
// CHECK-NEXT: %y = alloca [1024 x i8], align 16
// CHECK-NEXT: %x = alloca [1024 x i8], align 16
// CHECK: call void @kernel_1(ptr noalias noundef nonnull align 4 dereferenceable(1024) %x)
// CHECK-NEXT: call void @llvm.lifetime.start.p0(i64 8, ptr nonnull %0)
// CHECK-NEXT: store ptr %x, ptr %0, align 8
// CHECK-NEXT: call void asm sideeffect "", "r,~{memory}"(ptr nonnull %0)
// CHECK-NEXT: call void @llvm.lifetime.end.p0(i64 8, ptr nonnull %0)
// CHECK-NEXT: call void @llvm.lifetime.end.p0(i64 1024, ptr nonnull %x)
// CHECK-NEXT: ret void
// CHECK: call void @kernel_1(ptr {{.*}} %x, ptr {{.*}} %y)
// CHECK: store ptr %x, ptr %1, align 8
// CHECK: call void asm sideeffect "", "r,~{memory}"(ptr nonnull %1)
// CHECK: store ptr %y, ptr %0, align 8
// CHECK: call void asm sideeffect "", "r,~{memory}"(ptr nonnull %0)
// CHECK: ret void
// CHECK-NEXT: }

// CHECK: define{{( dso_local)?}} void @kernel_1(ptr noalias noundef align 4 dereferenceable(1024) %x)
// CHECK: define{{( dso_local)?}} void @kernel_1(ptr noalias noundef align 4 dereferenceable(1024) %x, ptr noalias noundef readonly align 4 captures(address, read_provenance) dereferenceable(1024) %y)
// CHECK-NEXT: start:
// CHECK-NEXT: %EmptyDesc = alloca %struct.__tgt_bin_desc, align 8
// CHECK-NEXT: %.offload_baseptrs = alloca [1 x ptr], align 8
// CHECK-NEXT: %.offload_ptrs = alloca [1 x ptr], align 8
// CHECK-NEXT: %.offload_sizes = alloca [1 x i64], align 8
// CHECK-NEXT: %.offload_baseptrs = alloca [2 x ptr], align 8
// CHECK-NEXT: %.offload_ptrs = alloca [2 x ptr], align 8
// CHECK-NEXT: %.offload_sizes = alloca [2 x i64], align 8
// CHECK-NEXT: %kernel_args = alloca %struct.__tgt_kernel_arguments, align 8
// CHECK-NEXT: %dummy = load volatile ptr, ptr @.offload_sizes._kernel_1, align 8
// CHECK-NEXT: %dummy1 = load volatile ptr, ptr @.offloading.entry._kernel_1, align 8
Expand All @@ -77,32 +82,38 @@ pub fn _kernel_1(x: &mut [f32; 256]) {
// CHECK-NEXT: store ptr %x, ptr %.offload_baseptrs, align 8
// CHECK-NEXT: store ptr %x, ptr %.offload_ptrs, align 8
// CHECK-NEXT: store i64 1024, ptr %.offload_sizes, align 8
// CHECK-NEXT: call void @__tgt_target_data_begin_mapper(ptr nonnull @anon.{{.*}}.1, i64 -1, i32 1, ptr nonnull %.offload_baseptrs, ptr nonnull %.offload_ptrs, ptr nonnull %.offload_sizes, ptr nonnull @.offload_maptypes._kernel_1, ptr null, ptr null)
// CHECK-NEXT: [[GEP_BPTRS_0:%.*]] = getelementptr inbounds nuw i8, ptr %.offload_baseptrs, i64 8
// CHECK-NEXT: store ptr %y, ptr [[GEP_BPTRS_0]], align 8
// CHECK-NEXT: [[GEP_PTRS_1:%.*]] = getelementptr inbounds nuw i8, ptr %.offload_ptrs, i64 8
// CHECK-NEXT: store ptr %y, ptr [[GEP_PTRS_1]], align 8
// CHECK-NEXT: [[GEP_SIZES_1:%.*]] = getelementptr inbounds nuw i8, ptr %.offload_sizes, i64 8
// CHECK-NEXT: store i64 1024, ptr [[GEP_SIZES_1]], align 8
// CHECK-NEXT: call void @__tgt_target_data_begin_mapper(ptr nonnull @anon.{{.*}}.1, i64 -1, i32 2, ptr nonnull %.offload_baseptrs, ptr nonnull %.offload_ptrs, ptr nonnull %.offload_sizes, ptr nonnull @.offload_maptypes._kernel_1.begin, ptr null, ptr null)
// CHECK-NEXT: store i32 3, ptr %kernel_args, align 8
// CHECK-NEXT: %0 = getelementptr inbounds nuw i8, ptr %kernel_args, i64 4
// CHECK-NEXT: store i32 1, ptr %0, align 4
// CHECK-NEXT: %1 = getelementptr inbounds nuw i8, ptr %kernel_args, i64 8
// CHECK-NEXT: store ptr %.offload_baseptrs, ptr %1, align 8
// CHECK-NEXT: %2 = getelementptr inbounds nuw i8, ptr %kernel_args, i64 16
// CHECK-NEXT: store ptr %.offload_ptrs, ptr %2, align 8
// CHECK-NEXT: %3 = getelementptr inbounds nuw i8, ptr %kernel_args, i64 24
// CHECK-NEXT: store ptr %.offload_sizes, ptr %3, align 8
// CHECK-NEXT: %4 = getelementptr inbounds nuw i8, ptr %kernel_args, i64 32
// CHECK-NEXT: store ptr @.offload_maptypes._kernel_1, ptr %4, align 8
// CHECK-NEXT: %5 = getelementptr inbounds nuw i8, ptr %kernel_args, i64 40
// CHECK-NEXT: %6 = getelementptr inbounds nuw i8, ptr %kernel_args, i64 72
// CHECK-NEXT: call void @llvm.memset.p0.i64(ptr noundef nonnull align 8 dereferenceable(32) %5, i8 0, i64 32, i1 false)
// CHECK-NEXT: store <4 x i32> <i32 256, i32 1, i32 1, i32 32>, ptr %6, align 8
// CHECK-NEXT: [[KARGS_OFF4:%.*]] = getelementptr inbounds nuw i8, ptr %kernel_args, i64 4
// CHECK-NEXT: store i32 2, ptr [[KARGS_OFF4]], align 4
// CHECK-NEXT: [[KARGS_OFF8:%.*]] = getelementptr inbounds nuw i8, ptr %kernel_args, i64 8
// CHECK-NEXT: store ptr %.offload_baseptrs, ptr [[KARGS_OFF8]], align 8
// CHECK-NEXT: [[KARGS_OFF16:%.*]] = getelementptr inbounds nuw i8, ptr %kernel_args, i64 16
// CHECK-NEXT: store ptr %.offload_ptrs, ptr [[KARGS_OFF16]], align 8
// CHECK-NEXT: [[KARGS_OFF24:%.*]] = getelementptr inbounds nuw i8, ptr %kernel_args, i64 24
// CHECK-NEXT: store ptr %.offload_sizes, ptr [[KARGS_OFF24]], align 8
// CHECK-NEXT: [[KARGS_OFF32:%.*]] = getelementptr inbounds nuw i8, ptr %kernel_args, i64 32
// CHECK-NEXT: store ptr @.offload_maptypes._kernel_1.kernel, ptr [[KARGS_OFF32]], align 8
// CHECK-NEXT: [[KARGS_OFF40:%.*]] = getelementptr inbounds nuw i8, ptr %kernel_args, i64 40
// CHECK-NEXT: [[KARGS_OFF72:%.*]] = getelementptr inbounds nuw i8, ptr %kernel_args, i64 72
// CHECK-NEXT: call void @llvm.memset.p0.i64(ptr noundef nonnull align 8 dereferenceable(32) [[KARGS_OFF40]], i8 0, i64 32, i1 false)
// CHECK-NEXT: store <4 x i32> <i32 256, i32 1, i32 1, i32 32>, ptr [[KARGS_OFF72]], align 8
// CHECK-NEXT: %.fca.1.gep5 = getelementptr inbounds nuw i8, ptr %kernel_args, i64 88
// CHECK-NEXT: store i32 1, ptr %.fca.1.gep5, align 8
// CHECK-NEXT: %.fca.2.gep7 = getelementptr inbounds nuw i8, ptr %kernel_args, i64 92
// CHECK-NEXT: store i32 1, ptr %.fca.2.gep7, align 4
// CHECK-NEXT: %7 = getelementptr inbounds nuw i8, ptr %kernel_args, i64 96
// CHECK-NEXT: store i32 0, ptr %7, align 8
// CHECK-NEXT: %8 = call i32 @__tgt_target_kernel(ptr nonnull @anon.{{.*}}.1, i64 -1, i32 256, i32 32, ptr nonnull @._kernel_1.region_id, ptr nonnull %kernel_args)
// CHECK-NEXT: call void @__tgt_target_data_end_mapper(ptr nonnull @anon.{{.*}}.1, i64 -1, i32 1, ptr nonnull %.offload_baseptrs, ptr nonnull %.offload_ptrs, ptr nonnull %.offload_sizes, ptr nonnull @.offload_maptypes._kernel_1, ptr null, ptr null)
// CHECK-NEXT: call void @__tgt_unregister_lib(ptr nonnull %EmptyDesc)
// CHECK-NEXT: ret void
// CHECK-NEXT: [[KARGS_OFF96:%.*]] = getelementptr inbounds nuw i8, ptr %kernel_args, i64 96
// CHECK-NEXT: store i32 0, ptr [[KARGS_OFF96]], align 8
// CHECK-NEXT: [[TGT_RET:%.*]] = call i32 @__tgt_target_kernel(ptr nonnull @anon.{{.*}}.1, i64 -1, i32 256, i32 32, ptr nonnull @._kernel_1.region_id, ptr nonnull %kernel_args)
// CHECK-NEXT: call void @__tgt_target_data_end_mapper(ptr nonnull @anon.{{.*}}.1, i64 -1, i32 2, ptr nonnull %.offload_baseptrs, ptr nonnull %.offload_ptrs, ptr nonnull %.offload_sizes, ptr nonnull @.offload_maptypes._kernel_1.end, ptr null, ptr null)
// CHECK-NEXT: call void @__tgt_unregister_lib(ptr nonnull %EmptyDesc)
// CHECK-NEXT: ret void
// CHECK-NEXT: }

// CHECK: !{i32 7, !"openmp", i32 51}
Loading