@@ -18285,65 +18285,216 @@ Value *CodeGenFunction::EmitAMDGPUBuiltinExpr(unsigned BuiltinID,
1828518285 case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x16_iu4_w32:
1828618286 case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x16_iu4_w64:
1828718287 case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x16_iu8_w32:
18288- case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x16_iu8_w64: {
18288+ case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x16_iu8_w64:
18289+ case AMDGPU::BI__builtin_amdgcn_wmma_bf16_16x16x16_bf16_w32_gfx12:
18290+ case AMDGPU::BI__builtin_amdgcn_wmma_bf16_16x16x16_bf16_w64_gfx12:
18291+ case AMDGPU::BI__builtin_amdgcn_wmma_f16_16x16x16_f16_w32_gfx12:
18292+ case AMDGPU::BI__builtin_amdgcn_wmma_f16_16x16x16_f16_w64_gfx12:
18293+ case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_bf16_w32_gfx12:
18294+ case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_bf16_w64_gfx12:
18295+ case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_f16_w32_gfx12:
18296+ case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_f16_w64_gfx12:
18297+ case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x16_iu4_w32_gfx12:
18298+ case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x16_iu4_w64_gfx12:
18299+ case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12:
18300+ case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x16_iu8_w64_gfx12:
18301+ case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_fp8_fp8_w32_gfx12:
18302+ case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_fp8_fp8_w64_gfx12:
18303+ case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_fp8_bf8_w32_gfx12:
18304+ case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_fp8_bf8_w64_gfx12:
18305+ case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_bf8_fp8_w32_gfx12:
18306+ case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_bf8_fp8_w64_gfx12:
18307+ case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_bf8_bf8_w32_gfx12:
18308+ case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_bf8_bf8_w64_gfx12:
18309+ case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x32_iu4_w32_gfx12:
18310+ case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x32_iu4_w64_gfx12:
18311+ case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_f16_w32:
18312+ case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_f16_w64:
18313+ case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_bf16_w32:
18314+ case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_bf16_w64:
18315+ case AMDGPU::BI__builtin_amdgcn_swmmac_f16_16x16x32_f16_w32:
18316+ case AMDGPU::BI__builtin_amdgcn_swmmac_f16_16x16x32_f16_w64:
18317+ case AMDGPU::BI__builtin_amdgcn_swmmac_bf16_16x16x32_bf16_w32:
18318+ case AMDGPU::BI__builtin_amdgcn_swmmac_bf16_16x16x32_bf16_w64:
18319+ case AMDGPU::BI__builtin_amdgcn_swmmac_i32_16x16x32_iu8_w32:
18320+ case AMDGPU::BI__builtin_amdgcn_swmmac_i32_16x16x32_iu8_w64:
18321+ case AMDGPU::BI__builtin_amdgcn_swmmac_i32_16x16x32_iu4_w32:
18322+ case AMDGPU::BI__builtin_amdgcn_swmmac_i32_16x16x32_iu4_w64:
18323+ case AMDGPU::BI__builtin_amdgcn_swmmac_i32_16x16x64_iu4_w32:
18324+ case AMDGPU::BI__builtin_amdgcn_swmmac_i32_16x16x64_iu4_w64:
18325+ case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_fp8_fp8_w32:
18326+ case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_fp8_fp8_w64:
18327+ case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_fp8_bf8_w32:
18328+ case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_fp8_bf8_w64:
18329+ case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_bf8_fp8_w32:
18330+ case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_bf8_fp8_w64:
18331+ case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_bf8_bf8_w32:
18332+ case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_bf8_bf8_w64: {
1828918333
1829018334 // These operations perform a matrix multiplication and accumulation of
1829118335 // the form:
1829218336 // D = A * B + C
18293- // The return type always matches the type of matrix C.
18294- unsigned ArgForMatchingRetType;
18337+ // We need to specify one type for matrices AB and one for matrices CD.
18338+ // Sparse matrix operations can have different types for A and B as well as
18339+ // an additional type for sparsity index.
18340+ // Destination type should be put before types used for source operands.
18341+ SmallVector<unsigned, 2> ArgsForMatchingMatrixTypes;
18342+ // On GFX12, the intrinsics with 16-bit accumulator use a packed layout.
18343+ // There is no need for the variable opsel argument, so always set it to
18344+ // "false".
18345+ bool AppendFalseForOpselArg = false;
1829518346 unsigned BuiltinWMMAOp;
1829618347
1829718348 switch (BuiltinID) {
1829818349 case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_f16_w32:
1829918350 case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_f16_w64:
18300- ArgForMatchingRetType = 2;
18351+ case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_f16_w32_gfx12:
18352+ case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_f16_w64_gfx12:
18353+ ArgsForMatchingMatrixTypes = {2, 0}; // CD, AB
1830118354 BuiltinWMMAOp = Intrinsic::amdgcn_wmma_f32_16x16x16_f16;
1830218355 break;
1830318356 case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_bf16_w32:
1830418357 case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_bf16_w64:
18305- ArgForMatchingRetType = 2;
18358+ case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_bf16_w32_gfx12:
18359+ case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_bf16_w64_gfx12:
18360+ ArgsForMatchingMatrixTypes = {2, 0}; // CD, AB
1830618361 BuiltinWMMAOp = Intrinsic::amdgcn_wmma_f32_16x16x16_bf16;
1830718362 break;
18363+ case AMDGPU::BI__builtin_amdgcn_wmma_f16_16x16x16_f16_w32_gfx12:
18364+ case AMDGPU::BI__builtin_amdgcn_wmma_f16_16x16x16_f16_w64_gfx12:
18365+ AppendFalseForOpselArg = true;
18366+ LLVM_FALLTHROUGH;
1830818367 case AMDGPU::BI__builtin_amdgcn_wmma_f16_16x16x16_f16_w32:
1830918368 case AMDGPU::BI__builtin_amdgcn_wmma_f16_16x16x16_f16_w64:
18310- ArgForMatchingRetType = 2;
18369+ ArgsForMatchingMatrixTypes = {2, 0}; // CD, AB
1831118370 BuiltinWMMAOp = Intrinsic::amdgcn_wmma_f16_16x16x16_f16;
1831218371 break;
18372+ case AMDGPU::BI__builtin_amdgcn_wmma_bf16_16x16x16_bf16_w32_gfx12:
18373+ case AMDGPU::BI__builtin_amdgcn_wmma_bf16_16x16x16_bf16_w64_gfx12:
18374+ AppendFalseForOpselArg = true;
18375+ LLVM_FALLTHROUGH;
1831318376 case AMDGPU::BI__builtin_amdgcn_wmma_bf16_16x16x16_bf16_w32:
1831418377 case AMDGPU::BI__builtin_amdgcn_wmma_bf16_16x16x16_bf16_w64:
18315- ArgForMatchingRetType = 2;
18378+ ArgsForMatchingMatrixTypes = {2, 0}; // CD, AB
1831618379 BuiltinWMMAOp = Intrinsic::amdgcn_wmma_bf16_16x16x16_bf16;
1831718380 break;
1831818381 case AMDGPU::BI__builtin_amdgcn_wmma_f16_16x16x16_f16_tied_w32:
1831918382 case AMDGPU::BI__builtin_amdgcn_wmma_f16_16x16x16_f16_tied_w64:
18320- ArgForMatchingRetType = 2;
18383+ ArgsForMatchingMatrixTypes = {2, 0}; // CD, AB
1832118384 BuiltinWMMAOp = Intrinsic::amdgcn_wmma_f16_16x16x16_f16_tied;
1832218385 break;
1832318386 case AMDGPU::BI__builtin_amdgcn_wmma_bf16_16x16x16_bf16_tied_w32:
1832418387 case AMDGPU::BI__builtin_amdgcn_wmma_bf16_16x16x16_bf16_tied_w64:
18325- ArgForMatchingRetType = 2;
18388+ ArgsForMatchingMatrixTypes = {2, 0}; // CD, AB
1832618389 BuiltinWMMAOp = Intrinsic::amdgcn_wmma_bf16_16x16x16_bf16_tied;
1832718390 break;
1832818391 case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x16_iu8_w32:
1832918392 case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x16_iu8_w64:
18330- ArgForMatchingRetType = 4;
18393+ case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12:
18394+ case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x16_iu8_w64_gfx12:
18395+ ArgsForMatchingMatrixTypes = {4, 1}; // CD, AB
1833118396 BuiltinWMMAOp = Intrinsic::amdgcn_wmma_i32_16x16x16_iu8;
1833218397 break;
1833318398 case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x16_iu4_w32:
1833418399 case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x16_iu4_w64:
18335- ArgForMatchingRetType = 4;
18400+ case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x16_iu4_w32_gfx12:
18401+ case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x16_iu4_w64_gfx12:
18402+ ArgsForMatchingMatrixTypes = {4, 1}; // CD, AB
1833618403 BuiltinWMMAOp = Intrinsic::amdgcn_wmma_i32_16x16x16_iu4;
1833718404 break;
18405+ case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_fp8_fp8_w32_gfx12:
18406+ case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_fp8_fp8_w64_gfx12:
18407+ ArgsForMatchingMatrixTypes = {2, 0}; // CD, AB
18408+ BuiltinWMMAOp = Intrinsic::amdgcn_wmma_f32_16x16x16_fp8_fp8;
18409+ break;
18410+ case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_fp8_bf8_w32_gfx12:
18411+ case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_fp8_bf8_w64_gfx12:
18412+ ArgsForMatchingMatrixTypes = {2, 0}; // CD, AB
18413+ BuiltinWMMAOp = Intrinsic::amdgcn_wmma_f32_16x16x16_fp8_bf8;
18414+ break;
18415+ case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_bf8_fp8_w32_gfx12:
18416+ case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_bf8_fp8_w64_gfx12:
18417+ ArgsForMatchingMatrixTypes = {2, 0}; // CD, AB
18418+ BuiltinWMMAOp = Intrinsic::amdgcn_wmma_f32_16x16x16_bf8_fp8;
18419+ break;
18420+ case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_bf8_bf8_w32_gfx12:
18421+ case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_bf8_bf8_w64_gfx12:
18422+ ArgsForMatchingMatrixTypes = {2, 0}; // CD, AB
18423+ BuiltinWMMAOp = Intrinsic::amdgcn_wmma_f32_16x16x16_bf8_bf8;
18424+ break;
18425+ case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x32_iu4_w32_gfx12:
18426+ case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x32_iu4_w64_gfx12:
18427+ ArgsForMatchingMatrixTypes = {4, 1}; // CD, AB
18428+ BuiltinWMMAOp = Intrinsic::amdgcn_wmma_i32_16x16x32_iu4;
18429+ break;
18430+ case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_f16_w32:
18431+ case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_f16_w64:
18432+ ArgsForMatchingMatrixTypes = {2, 0, 1, 3}; // CD, A, B, Index
18433+ BuiltinWMMAOp = Intrinsic::amdgcn_swmmac_f32_16x16x32_f16;
18434+ break;
18435+ case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_bf16_w32:
18436+ case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_bf16_w64:
18437+ ArgsForMatchingMatrixTypes = {2, 0, 1, 3}; // CD, A, B, Index
18438+ BuiltinWMMAOp = Intrinsic::amdgcn_swmmac_f32_16x16x32_bf16;
18439+ break;
18440+ case AMDGPU::BI__builtin_amdgcn_swmmac_f16_16x16x32_f16_w32:
18441+ case AMDGPU::BI__builtin_amdgcn_swmmac_f16_16x16x32_f16_w64:
18442+ ArgsForMatchingMatrixTypes = {2, 0, 1, 3}; // CD, A, B, Index
18443+ BuiltinWMMAOp = Intrinsic::amdgcn_swmmac_f16_16x16x32_f16;
18444+ break;
18445+ case AMDGPU::BI__builtin_amdgcn_swmmac_bf16_16x16x32_bf16_w32:
18446+ case AMDGPU::BI__builtin_amdgcn_swmmac_bf16_16x16x32_bf16_w64:
18447+ ArgsForMatchingMatrixTypes = {2, 0, 1, 3}; // CD, A, B, Index
18448+ BuiltinWMMAOp = Intrinsic::amdgcn_swmmac_bf16_16x16x32_bf16;
18449+ break;
18450+ case AMDGPU::BI__builtin_amdgcn_swmmac_i32_16x16x32_iu8_w32:
18451+ case AMDGPU::BI__builtin_amdgcn_swmmac_i32_16x16x32_iu8_w64:
18452+ ArgsForMatchingMatrixTypes = {4, 1, 3, 5}; // CD, A, B, Index
18453+ BuiltinWMMAOp = Intrinsic::amdgcn_swmmac_i32_16x16x32_iu8;
18454+ break;
18455+ case AMDGPU::BI__builtin_amdgcn_swmmac_i32_16x16x32_iu4_w32:
18456+ case AMDGPU::BI__builtin_amdgcn_swmmac_i32_16x16x32_iu4_w64:
18457+ ArgsForMatchingMatrixTypes = {4, 1, 3, 5}; // CD, A, B, Index
18458+ BuiltinWMMAOp = Intrinsic::amdgcn_swmmac_i32_16x16x32_iu4;
18459+ break;
18460+ case AMDGPU::BI__builtin_amdgcn_swmmac_i32_16x16x64_iu4_w32:
18461+ case AMDGPU::BI__builtin_amdgcn_swmmac_i32_16x16x64_iu4_w64:
18462+ ArgsForMatchingMatrixTypes = {4, 1, 3, 5}; // CD, A, B, Index
18463+ BuiltinWMMAOp = Intrinsic::amdgcn_swmmac_i32_16x16x64_iu4;
18464+ break;
18465+ case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_fp8_fp8_w32:
18466+ case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_fp8_fp8_w64:
18467+ ArgsForMatchingMatrixTypes = {2, 0, 1, 3}; // CD, A, B, Index
18468+ BuiltinWMMAOp = Intrinsic::amdgcn_swmmac_f32_16x16x32_fp8_fp8;
18469+ break;
18470+ case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_fp8_bf8_w32:
18471+ case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_fp8_bf8_w64:
18472+ ArgsForMatchingMatrixTypes = {2, 0, 1, 3}; // CD, A, B, Index
18473+ BuiltinWMMAOp = Intrinsic::amdgcn_swmmac_f32_16x16x32_fp8_bf8;
18474+ break;
18475+ case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_bf8_fp8_w32:
18476+ case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_bf8_fp8_w64:
18477+ ArgsForMatchingMatrixTypes = {2, 0, 1, 3}; // CD, A, B, Index
18478+ BuiltinWMMAOp = Intrinsic::amdgcn_swmmac_f32_16x16x32_bf8_fp8;
18479+ break;
18480+ case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_bf8_bf8_w32:
18481+ case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_bf8_bf8_w64:
18482+ ArgsForMatchingMatrixTypes = {2, 0, 1, 3}; // CD, A, B, Index
18483+ BuiltinWMMAOp = Intrinsic::amdgcn_swmmac_f32_16x16x32_bf8_bf8;
18484+ break;
1833818485 }
1833918486
1834018487 SmallVector<Value *, 6> Args;
1834118488 for (int i = 0, e = E->getNumArgs(); i != e; ++i)
1834218489 Args.push_back(EmitScalarExpr(E->getArg(i)));
18490+ if (AppendFalseForOpselArg)
18491+ Args.push_back(Builder.getFalse());
1834318492
18344- Function *F = CGM.getIntrinsic(BuiltinWMMAOp,
18345- {Args[ArgForMatchingRetType]->getType()});
18493+ SmallVector<llvm::Type *, 6> ArgTypes;
18494+ for (auto ArgIdx : ArgsForMatchingMatrixTypes)
18495+ ArgTypes.push_back(Args[ArgIdx]->getType());
1834618496
18497+ Function *F = CGM.getIntrinsic(BuiltinWMMAOp, ArgTypes);
1834718498 return Builder.CreateCall(F, Args);
1834818499 }
1834918500
0 commit comments