@@ -1508,6 +1508,9 @@ static bool IsSVECalleeSave(MachineBasicBlock::iterator I) {
15081508 switch (I->getOpcode ()) {
15091509 default :
15101510 return false ;
1511+ case AArch64::PTRUE_C_B:
1512+ case AArch64::LD1B_2Z_IMM:
1513+ case AArch64::ST1B_2Z_IMM:
15111514 case AArch64::STR_ZXI:
15121515 case AArch64::STR_PXI:
15131516 case AArch64::LDR_ZXI:
@@ -2781,6 +2784,16 @@ struct RegPairInfo {
27812784
27822785} // end anonymous namespace
27832786
2787+ unsigned findFreePredicateReg (BitVector &SavedRegs) {
2788+ for (unsigned PReg = AArch64::P8; PReg <= AArch64::P15; ++PReg) {
2789+ if (SavedRegs.test (PReg)) {
2790+ unsigned PNReg = PReg - AArch64::P0 + AArch64::PN0;
2791+ return PNReg;
2792+ }
2793+ }
2794+ return AArch64::NoRegister;
2795+ }
2796+
27842797static void computeCalleeSaveRegisterPairs (
27852798 MachineFunction &MF, ArrayRef<CalleeSavedInfo> CSI,
27862799 const TargetRegisterInfo *TRI, SmallVectorImpl<RegPairInfo> &RegPairs,
@@ -2859,7 +2872,11 @@ static void computeCalleeSaveRegisterPairs(
28592872 RPI.Reg2 = NextReg;
28602873 break ;
28612874 case RegPairInfo::PPR:
2875+ break ;
28622876 case RegPairInfo::ZPR:
2877+ if (AFI->getPredicateRegForFillSpill () != 0 )
2878+ if (((RPI.Reg1 - AArch64::Z0) & 1 ) == 0 && (NextReg == RPI.Reg1 + 1 ))
2879+ RPI.Reg2 = NextReg;
28632880 break ;
28642881 }
28652882 }
@@ -2897,14 +2914,13 @@ static void computeCalleeSaveRegisterPairs(
28972914 if (NeedsWinCFI &&
28982915 RPI.isPaired ()) // RPI.FrameIdx must be the lower index of the pair
28992916 RPI.FrameIdx = CSI[i + RegInc].getFrameIdx ();
2900-
29012917 int Scale = RPI.getScale ();
29022918
29032919 int OffsetPre = RPI.isScalable () ? ScalableByteOffset : ByteOffset;
29042920 assert (OffsetPre % Scale == 0 );
29052921
29062922 if (RPI.isScalable ())
2907- ScalableByteOffset += StackFillDir * Scale;
2923+ ScalableByteOffset += StackFillDir * (RPI. isPaired () ? 2 * Scale : Scale) ;
29082924 else
29092925 ByteOffset += StackFillDir * (RPI.isPaired () ? 2 * Scale : Scale);
29102926
@@ -2915,9 +2931,6 @@ static void computeCalleeSaveRegisterPairs(
29152931 (IsWindows && RPI.Reg2 == AArch64::LR)))
29162932 ByteOffset += StackFillDir * 8 ;
29172933
2918- assert (!(RPI.isScalable () && RPI.isPaired ()) &&
2919- " Paired spill/fill instructions don't exist for SVE vectors" );
2920-
29212934 // Round up size of non-pair to pair size if we need to pad the
29222935 // callee-save area to ensure 16-byte alignment.
29232936 if (NeedGapToAlignStack && !NeedsWinCFI &&
@@ -3004,6 +3017,7 @@ bool AArch64FrameLowering::spillCalleeSavedRegisters(
30043017 }
30053018 return true ;
30063019 }
3020+ bool PTrueCreated = false ;
30073021 for (const RegPairInfo &RPI : llvm::reverse (RegPairs)) {
30083022 unsigned Reg1 = RPI.Reg1 ;
30093023 unsigned Reg2 = RPI.Reg2 ;
@@ -3038,10 +3052,10 @@ bool AArch64FrameLowering::spillCalleeSavedRegisters(
30383052 Alignment = Align (16 );
30393053 break ;
30403054 case RegPairInfo::ZPR:
3041- StrOpc = AArch64::STR_ZXI;
3042- Size = 16 ;
3043- Alignment = Align (16 );
3044- break ;
3055+ StrOpc = RPI. isPaired () ? AArch64::ST1B_2Z_IMM : AArch64::STR_ZXI;
3056+ Size = 16 ;
3057+ Alignment = Align (16 );
3058+ break ;
30453059 case RegPairInfo::PPR:
30463060 StrOpc = AArch64::STR_PXI;
30473061 Size = 2 ;
@@ -3065,33 +3079,79 @@ bool AArch64FrameLowering::spillCalleeSavedRegisters(
30653079 std::swap (Reg1, Reg2);
30663080 std::swap (FrameIdxReg1, FrameIdxReg2);
30673081 }
3068- MachineInstrBuilder MIB = BuildMI (MBB, MI, DL, TII.get (StrOpc));
3069- if (!MRI.isReserved (Reg1))
3070- MBB.addLiveIn (Reg1);
3071- if (RPI.isPaired ()) {
3082+
3083+ if (RPI.isPaired () && RPI.isScalable ()) {
3084+ const AArch64Subtarget &Subtarget = MF.getSubtarget <AArch64Subtarget>();
3085+ AArch64FunctionInfo *AFI = MF.getInfo <AArch64FunctionInfo>();
3086+ unsigned PnReg = AFI->getPredicateRegForFillSpill ();
3087+ assert (((Subtarget.hasSVE2p1 () || Subtarget.hasSME2 ()) && PnReg != 0 ) &&
3088+ " Expects SVE2.1 or SME2 target and a predicate register" );
3089+ #ifdef EXPENSIVE_CHECKS
3090+ auto IsPPR = [](const RegPairInfo &c) {
3091+ return c.Reg1 == RegPairInfo::PPR;
3092+ };
3093+ auto PPRBegin = std::find_if (RegPairs.begin (), RegPairs.end (), IsPPR);
3094+ auto IsZPR = [](const RegPairInfo &c) {
3095+ return c.Type == RegPairInfo::ZPR;
3096+ };
3097+ auto ZPRBegin = std::find_if (RegPairs.begin (), RegPairs.end (), IsZPR);
3098+ assert (!(PPRBegin < ZPRBegin) &&
3099+ " Expected callee save predicate to be handled first" );
3100+ #endif
3101+ if (!PTrueCreated) {
3102+ PTrueCreated = true ;
3103+ BuildMI (MBB, MI, DL, TII.get (AArch64::PTRUE_C_B), PnReg)
3104+ .setMIFlags (MachineInstr::FrameSetup);
3105+ }
3106+ MachineInstrBuilder MIB = BuildMI (MBB, MI, DL, TII.get (StrOpc));
3107+ if (!MRI.isReserved (Reg1))
3108+ MBB.addLiveIn (Reg1);
30723109 if (!MRI.isReserved (Reg2))
30733110 MBB.addLiveIn (Reg2);
3074- MIB.addReg (Reg2, getPrologueDeath (MF, Reg2 ));
3111+ MIB.addReg (/* PairRegs */ AArch64::Z0_Z1 + (RPI. Reg1 - AArch64::Z0 ));
30753112 MIB.addMemOperand (MF.getMachineMemOperand (
30763113 MachinePointerInfo::getFixedStack (MF, FrameIdxReg2),
30773114 MachineMemOperand::MOStore, Size, Alignment));
3115+ MIB.addReg (PnReg);
3116+ MIB.addReg (AArch64::SP)
3117+ .addImm (RPI.Offset ) // [sp, #offset*scale],
3118+ // where factor*scale is implicit
3119+ .setMIFlag (MachineInstr::FrameSetup);
3120+ MIB.addMemOperand (MF.getMachineMemOperand (
3121+ MachinePointerInfo::getFixedStack (MF, FrameIdxReg1),
3122+ MachineMemOperand::MOStore, Size, Alignment));
3123+ if (NeedsWinCFI)
3124+ InsertSEH (MIB, TII, MachineInstr::FrameSetup);
3125+ } else { // The code when the pair of ZReg is not present
3126+ MachineInstrBuilder MIB = BuildMI (MBB, MI, DL, TII.get (StrOpc));
3127+ if (!MRI.isReserved (Reg1))
3128+ MBB.addLiveIn (Reg1);
3129+ if (RPI.isPaired ()) {
3130+ if (!MRI.isReserved (Reg2))
3131+ MBB.addLiveIn (Reg2);
3132+ MIB.addReg (Reg2, getPrologueDeath (MF, Reg2));
3133+ MIB.addMemOperand (MF.getMachineMemOperand (
3134+ MachinePointerInfo::getFixedStack (MF, FrameIdxReg2),
3135+ MachineMemOperand::MOStore, Size, Alignment));
3136+ }
3137+ MIB.addReg (Reg1, getPrologueDeath (MF, Reg1))
3138+ .addReg (AArch64::SP)
3139+ .addImm (RPI.Offset ) // [sp, #offset*scale],
3140+ // where factor*scale is implicit
3141+ .setMIFlag (MachineInstr::FrameSetup);
3142+ MIB.addMemOperand (MF.getMachineMemOperand (
3143+ MachinePointerInfo::getFixedStack (MF, FrameIdxReg1),
3144+ MachineMemOperand::MOStore, Size, Alignment));
3145+ if (NeedsWinCFI)
3146+ InsertSEH (MIB, TII, MachineInstr::FrameSetup);
30783147 }
3079- MIB.addReg (Reg1, getPrologueDeath (MF, Reg1))
3080- .addReg (AArch64::SP)
3081- .addImm (RPI.Offset ) // [sp, #offset*scale],
3082- // where factor*scale is implicit
3083- .setMIFlag (MachineInstr::FrameSetup);
3084- MIB.addMemOperand (MF.getMachineMemOperand (
3085- MachinePointerInfo::getFixedStack (MF, FrameIdxReg1),
3086- MachineMemOperand::MOStore, Size, Alignment));
3087- if (NeedsWinCFI)
3088- InsertSEH (MIB, TII, MachineInstr::FrameSetup);
3089-
30903148 // Update the StackIDs of the SVE stack slots.
30913149 MachineFrameInfo &MFI = MF.getFrameInfo ();
3092- if (RPI.Type == RegPairInfo::ZPR || RPI.Type == RegPairInfo::PPR)
3093- MFI.setStackID (RPI.FrameIdx , TargetStackID::ScalableVector);
3094-
3150+ if (RPI.Type == RegPairInfo::ZPR || RPI.Type == RegPairInfo::PPR) {
3151+ MFI.setStackID (FrameIdxReg1, TargetStackID::ScalableVector);
3152+ if (RPI.isPaired ())
3153+ MFI.setStackID (FrameIdxReg2, TargetStackID::ScalableVector);
3154+ }
30953155 }
30963156 return true ;
30973157}
@@ -3109,7 +3169,6 @@ bool AArch64FrameLowering::restoreCalleeSavedRegisters(
31093169 DL = MBBI->getDebugLoc ();
31103170
31113171 computeCalleeSaveRegisterPairs (MF, CSI, TRI, RegPairs, hasFP (MF));
3112-
31133172 if (homogeneousPrologEpilog (MF, &MBB)) {
31143173 auto MIB = BuildMI (MBB, MBBI, DL, TII.get (AArch64::HOM_Epilog))
31153174 .setMIFlag (MachineInstr::FrameDestroy);
@@ -3130,6 +3189,7 @@ bool AArch64FrameLowering::restoreCalleeSavedRegisters(
31303189 auto ZPREnd = std::find_if_not (ZPRBegin, RegPairs.end (), IsZPR);
31313190 std::reverse (ZPRBegin, ZPREnd);
31323191
3192+ bool PTrueCreated = false ;
31333193 for (const RegPairInfo &RPI : RegPairs) {
31343194 unsigned Reg1 = RPI.Reg1 ;
31353195 unsigned Reg2 = RPI.Reg2 ;
@@ -3162,7 +3222,7 @@ bool AArch64FrameLowering::restoreCalleeSavedRegisters(
31623222 Alignment = Align (16 );
31633223 break ;
31643224 case RegPairInfo::ZPR:
3165- LdrOpc = AArch64::LDR_ZXI;
3225+ LdrOpc = RPI. isPaired () ? AArch64::LD1B_2Z_IMM : AArch64::LDR_ZXI;
31663226 Size = 16 ;
31673227 Alignment = Align (16 );
31683228 break ;
@@ -3187,25 +3247,58 @@ bool AArch64FrameLowering::restoreCalleeSavedRegisters(
31873247 std::swap (Reg1, Reg2);
31883248 std::swap (FrameIdxReg1, FrameIdxReg2);
31893249 }
3190- MachineInstrBuilder MIB = BuildMI (MBB, MBBI, DL, TII.get (LdrOpc));
3191- if (RPI.isPaired ()) {
3192- MIB.addReg (Reg2, getDefRegState (true ));
3250+
3251+ AArch64FunctionInfo *AFI = MF.getInfo <AArch64FunctionInfo>();
3252+ if (RPI.isPaired () && RPI.isScalable ()) {
3253+ const AArch64Subtarget &Subtarget = MF.getSubtarget <AArch64Subtarget>();
3254+ unsigned PnReg = AFI->getPredicateRegForFillSpill ();
3255+ assert (((Subtarget.hasSVE2p1 () || Subtarget.hasSME2 ()) && PnReg != 0 ) &&
3256+ " Expects SVE2.1 or SME2 target and a predicate register" );
3257+ #ifdef EXPENSIVE_CHECKS
3258+ assert (!(PPRBegin < ZPRBegin) &&
3259+ " Expected callee save predicate to be handled first" );
3260+ #endif
3261+ if (!PTrueCreated) {
3262+ PTrueCreated = true ;
3263+ BuildMI (MBB, MBBI, DL, TII.get (AArch64::PTRUE_C_B), PnReg)
3264+ .setMIFlags (MachineInstr::FrameDestroy);
3265+ }
3266+ MachineInstrBuilder MIB = BuildMI (MBB, MBBI, DL, TII.get (LdrOpc));
3267+ MIB.addReg (/* PairRegs*/ AArch64::Z0_Z1 + (RPI.Reg1 - AArch64::Z0),
3268+ getDefRegState (true ));
31933269 MIB.addMemOperand (MF.getMachineMemOperand (
31943270 MachinePointerInfo::getFixedStack (MF, FrameIdxReg2),
31953271 MachineMemOperand::MOLoad, Size, Alignment));
3272+ MIB.addReg (PnReg);
3273+ MIB.addReg (AArch64::SP)
3274+ .addImm (RPI.Offset ) // [sp, #offset*scale]
3275+ // where factor*scale is implicit
3276+ .setMIFlag (MachineInstr::FrameDestroy);
3277+ MIB.addMemOperand (MF.getMachineMemOperand (
3278+ MachinePointerInfo::getFixedStack (MF, FrameIdxReg1),
3279+ MachineMemOperand::MOLoad, Size, Alignment));
3280+ if (NeedsWinCFI)
3281+ InsertSEH (MIB, TII, MachineInstr::FrameDestroy);
3282+ } else {
3283+ MachineInstrBuilder MIB = BuildMI (MBB, MBBI, DL, TII.get (LdrOpc));
3284+ if (RPI.isPaired ()) {
3285+ MIB.addReg (Reg2, getDefRegState (true ));
3286+ MIB.addMemOperand (MF.getMachineMemOperand (
3287+ MachinePointerInfo::getFixedStack (MF, FrameIdxReg2),
3288+ MachineMemOperand::MOLoad, Size, Alignment));
3289+ }
3290+ MIB.addReg (Reg1, getDefRegState (true ));
3291+ MIB.addReg (AArch64::SP)
3292+ .addImm (RPI.Offset ) // [sp, #offset*scale]
3293+ // where factor*scale is implicit
3294+ .setMIFlag (MachineInstr::FrameDestroy);
3295+ MIB.addMemOperand (MF.getMachineMemOperand (
3296+ MachinePointerInfo::getFixedStack (MF, FrameIdxReg1),
3297+ MachineMemOperand::MOLoad, Size, Alignment));
3298+ if (NeedsWinCFI)
3299+ InsertSEH (MIB, TII, MachineInstr::FrameDestroy);
31963300 }
3197- MIB.addReg (Reg1, getDefRegState (true ))
3198- .addReg (AArch64::SP)
3199- .addImm (RPI.Offset ) // [sp, #offset*scale]
3200- // where factor*scale is implicit
3201- .setMIFlag (MachineInstr::FrameDestroy);
3202- MIB.addMemOperand (MF.getMachineMemOperand (
3203- MachinePointerInfo::getFixedStack (MF, FrameIdxReg1),
3204- MachineMemOperand::MOLoad, Size, Alignment));
3205- if (NeedsWinCFI)
3206- InsertSEH (MIB, TII, MachineInstr::FrameDestroy);
32073301 }
3208-
32093302 return true ;
32103303}
32113304
@@ -3234,6 +3327,7 @@ void AArch64FrameLowering::determineCalleeSaves(MachineFunction &MF,
32343327
32353328 unsigned ExtraCSSpill = 0 ;
32363329 bool HasUnpairedGPR64 = false ;
3330+ bool HasPairZReg = false ;
32373331 // Figure out which callee-saved registers to save/restore.
32383332 for (unsigned i = 0 ; CSRegs[i]; ++i) {
32393333 const unsigned Reg = CSRegs[i];
@@ -3287,6 +3381,28 @@ void AArch64FrameLowering::determineCalleeSaves(MachineFunction &MF,
32873381 !RegInfo->isReservedReg (MF, PairedReg))
32883382 ExtraCSSpill = PairedReg;
32893383 }
3384+ // Check whether a pair of ZRegs is saved, so a PReg can be selected for spill/fill.
3385+ HasPairZReg |= (AArch64::ZPRRegClass.contains (Reg, CSRegs[i ^ 1 ]) &&
3386+ SavedRegs.test (CSRegs[i ^ 1 ]));
3387+ }
3388+
3389+ if (HasPairZReg && (Subtarget.hasSVE2p1 () || Subtarget.hasSME2 ())) {
3390+ AArch64FunctionInfo *AFI = MF.getInfo <AArch64FunctionInfo>();
3391+ // Find a suitable predicate register for the multi-vector spill/fill
3392+ // instructions.
3393+ unsigned PnReg = findFreePredicateReg (SavedRegs);
3394+ if (PnReg != AArch64::NoRegister)
3395+ AFI->setPredicateRegForFillSpill (PnReg);
3396+ // If no free callee-saved predicate has been found, assign P8/PN8.
3397+ if (!AFI->getPredicateRegForFillSpill () &&
3398+ MF.getFunction ().getCallingConv () ==
3399+ CallingConv::AArch64_SVE_VectorCall) {
3400+ SavedRegs.set (AArch64::P8);
3401+ AFI->setPredicateRegForFillSpill (AArch64::PN8);
3402+ }
3403+
3404+ assert (!RegInfo->isReservedReg (MF, AFI->getPredicateRegForFillSpill ()) &&
3405+ " Predicate cannot be a reserved register" );
32903406 }
32913407
32923408 if (MF.getFunction ().getCallingConv () == CallingConv::Win64 &&
0 commit comments