// RUN: mlir-opt %s --convert-nvgpu-to-nvvm -gpu-kernel-outlining \
// RUN: -convert-scf-to-cf -convert-nvvm-to-llvm \
// RUN: -convert-vector-to-llvm \
// RUN: -convert-math-to-llvm \
// RUN: -expand-strided-metadata \
// RUN: -lower-affine \
// RUN: -convert-index-to-llvm=index-bitwidth=32 \
// RUN: -convert-arith-to-llvm \
// RUN: -finalize-memref-to-llvm \
// RUN: -convert-func-to-llvm \
// RUN: -canonicalize \
// RUN: | mlir-opt -pass-pipeline='builtin.module(gpu.module(strip-debuginfo,convert-gpu-to-nvvm,convert-nvgpu-to-nvvm{use-opaque-pointers=1},lower-affine,convert-scf-to-cf,convert-vector-to-llvm,convert-math-to-llvm,expand-strided-metadata,lower-affine,convert-index-to-llvm{index-bitwidth=32},convert-arith-to-llvm,reconcile-unrealized-casts,gpu-to-cubin{chip=sm_90 features=+ptx80 dump-ptx}))' \
// RUN: 2>&1 | FileCheck %s --check-prefixes=CHECK-PTX

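// The pipeline above dumps the PTX generated for the kernel (dump-ptx); the
// checks below verify that it contains the mbarrier setup, the two 2-D TMA
// bulk-tensor loads, and the final wait on the barrier.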
// CHECK-PTX: mbarrier.init.shared.b64
// CHECK-PTX: mbarrier.arrive.expect_tx.shared.b64
// CHECK-PTX: cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::complete_tx::bytes
// CHECK-PTX: cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::complete_tx::bytes
// CHECK-PTX: mbarrier.arrive.expect_tx.shared.b64
// CHECK-PTX: mbarrier.try_wait.parity.shared.b64

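// Test of nvgpu.tma.async.load on sm_90: the host fills a 64x8 (lhs) and an
// 8x128 (rhs) buffer, copies them to the device, builds a TMA descriptor for
// each, and launches a single block of 128 threads in which thread 0 issues
// TMA loads of both tiles into shared memory, synchronized with an mbarrier.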
module @mymod {
  memref.global "private" @bufferLhsGlobal : memref<64x8xf32, 3>
  memref.global "private" @bufferRhsGlobal : memref<8x128xf32, 3>
  func.func @main() {
    %c10000000 = arith.constant 10000000 : index
    %c6144 = arith.constant 6144 : index
    %c45 = arith.constant 45 : index
    %c7 = arith.constant 7 : index
    %c64 = arith.constant 64 : index
    %c1 = arith.constant 1 : index
    %c0 = arith.constant 0 : index
    %c8 = arith.constant 8 : index
    %c128 = arith.constant 128 : index
    %cst = arith.constant 3.000000e+00 : f32
    %alloc = memref.alloc() : memref<64x8xf32>
    %alloc_0 = memref.alloc() : memref<8x128xf32>
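    // Initialize the host buffers: rhs is filled with 3.0, lhs with the
    // column index converted to f32.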
    scf.for %arg0 = %c0 to %c8 step %c1 {
      scf.for %arg1 = %c0 to %c128 step %c1 {
        memref.store %cst, %alloc_0[%arg0, %arg1] : memref<8x128xf32>
      }
    }
    scf.for %arg0 = %c0 to %c64 step %c1 {
      scf.for %arg1 = %c0 to %c8 step %c1 {
        %5 = arith.index_cast %arg1 : index to i64
        %6 = arith.uitofp %5 : i64 to f32
        memref.store %6, %alloc[%arg0, %arg1] : memref<64x8xf32>
      }
    }
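    // Allocate the device buffers and copy the host data over asynchronously.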
    %0 = gpu.wait async
    %memref, %asyncToken = gpu.alloc async [%0] () : memref<64x8xf32>
    %memref_1, %asyncToken_2 = gpu.alloc async [%0] () : memref<8x128xf32>
    %1 = gpu.memcpy async [%0] %memref, %alloc : memref<64x8xf32>, memref<64x8xf32>
    %2 = gpu.memcpy async [%0] %memref_1, %alloc_0 : memref<8x128xf32>, memref<8x128xf32>
    %cast = memref.cast %memref : memref<64x8xf32> to memref<*xf32>
    %cast_3 = memref.cast %memref_1 : memref<8x128xf32> to memref<*xf32>
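    // Build one TMA descriptor per buffer, each describing a single box that
    // covers the whole tile (64x8 and 8x128), with no swizzling or interleave.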
    %3 = nvgpu.tma.create.descriptor %cast box[%c64, %c8] : memref<*xf32> -> <tensor = memref<64x8xf32, 3>, swizzle = none, l2promo = none, oob = zero, interleave = none>
    %4 = nvgpu.tma.create.descriptor %cast_3 box[%c8, %c128] : memref<*xf32> -> <tensor = memref<8x128xf32, 3>, swizzle = none, l2promo = none, oob = zero, interleave = none>
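    // Launch one block of 128 threads; the shared-memory destinations are the
    // module-level globals @bufferLhsGlobal and @bufferRhsGlobal.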
    gpu.launch blocks(%arg0, %arg1, %arg2) in (%arg6 = %c1, %arg7 = %c1, %arg8 = %c1) threads(%arg3, %arg4, %arg5) in (%arg9 = %c128, %arg10 = %c1, %arg11 = %c1) {
      %5 = gpu.block_dim x
      %6 = gpu.thread_id x
      %7 = memref.get_global @bufferLhsGlobal : memref<64x8xf32, 3>
      %8 = memref.get_global @bufferRhsGlobal : memref<8x128xf32, 3>
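      // Create an mbarrier in workgroup (shared) memory and initialize it with
      // an arrival count equal to the block size.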
      %9 = nvgpu.mbarrier.create -> <memorySpace = #gpu.address_space<workgroup>>
      nvgpu.mbarrier.init %9, %5 : <memorySpace = #gpu.address_space<workgroup>>
      gpu.barrier
      %10 = arith.cmpi eq, %6, %c0 : index
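      // Thread 0 announces the expected transaction size (6144 bytes =
      // (64*8 + 8*128) f32 elements * 4 bytes) and issues both TMA loads;
      // every other thread arrives with an expected size of zero.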
      scf.if %10 {
        nvgpu.mbarrier.arrive.expect_tx %9, %c6144 : <memorySpace = #gpu.address_space<workgroup>>
        %11 = memref.load %7[%c0, %c0] : memref<64x8xf32, 3>
        %12 = memref.load %8[%c0, %c0] : memref<8x128xf32, 3>
        gpu.printf "[GPU] TMA BEFORE lhs[45][7] %f\0A" %11 : f32
        gpu.printf "[GPU] TMA BEFORE rhs[7][0] %f\0A" %12 : f32
        nvgpu.tma.async.load %3[%c0, %c0], %9 to %7 : <tensor = memref<64x8xf32, 3>, swizzle = none, l2promo = none, oob = zero, interleave = none>, <memorySpace = #gpu.address_space<workgroup>> -> memref<64x8xf32, 3>
        nvgpu.tma.async.load %4[%c0, %c0], %9 to %8 : <tensor = memref<8x128xf32, 3>, swizzle = none, l2promo = none, oob = zero, interleave = none>, <memorySpace = #gpu.address_space<workgroup>> -> memref<8x128xf32, 3>
      } else {
        nvgpu.mbarrier.arrive.expect_tx %9, %c0 : <memorySpace = #gpu.address_space<workgroup>>
      }
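      // Every thread then waits on phase parity 0 of the barrier (with
      // %c10000000 ticks as the wait bound) until the TMA transfers complete.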
      nvgpu.mbarrier.try_wait.parity %9, %c0, %c10000000 : <memorySpace = #gpu.address_space<workgroup>>
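      // After the wait, thread 0 reads lhs[45][7] and rhs[7][0] back from
      // shared memory and prints the loaded values.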
      scf.if %10 {
        %11 = memref.load %7[%c45, %c7] : memref<64x8xf32, 3>
        %12 = memref.load %8[%c7, %c0] : memref<8x128xf32, 3>
        gpu.printf "[GPU] TMA LOADED lhs[45][7] %f\0A" %11 : f32
        gpu.printf "[GPU] TMA LOADED rhs[7][0] %f\0A" %12 : f32
      }
      gpu.terminator
    }
    return
  }
}