@@ -810,13 +810,9 @@ padThroughLoopIterArg(RewriterBase &rewriter, Value paddedValueBeforeHoisting,
810810 OpBuilder::InsertionGuard g (rewriter);
811811 rewriter.setInsertionPointAfter (hoistedPackedTensor.getDefiningOp ());
812812
813- std::optional<unsigned > maybeOperandNumber =
814- forOp.getIterArgNumberForOpOperand (*pUse);
815- assert (maybeOperandNumber.has_value () && " expected a proper iter arg number" );
816-
817- int64_t operandNumber = maybeOperandNumber.value ();
813+ unsigned iterArgNumber = forOp.getResultForOpOperand (*pUse).getResultNumber ();
818814 auto yieldOp = cast<scf::YieldOp>(forOp.getBody (0 )->getTerminator ());
819- auto yieldingExtractSliceOp = yieldOp->getOperand (operandNumber )
815+ auto yieldingExtractSliceOp = yieldOp->getOperand (iterArgNumber )
820816 .getDefiningOp <tensor::ExtractSliceOp>();
821817 if (!yieldingExtractSliceOp)
822818 return tensor::ExtractSliceOp ();
@@ -829,9 +825,9 @@ padThroughLoopIterArg(RewriterBase &rewriter, Value paddedValueBeforeHoisting,
829825 return tensor::ExtractSliceOp ();
830826
831827 SmallVector<Value> initArgs = forOp.getInitArgs ();
832- initArgs[operandNumber ] = hoistedPackedTensor;
828+ initArgs[iterArgNumber ] = hoistedPackedTensor;
833829 SmallVector<Value> yieldOperands = yieldOp.getOperands ();
834- yieldOperands[operandNumber ] = yieldingExtractSliceOp.getSource ();
830+ yieldOperands[iterArgNumber ] = yieldingExtractSliceOp.getSource ();
835831
836832 int64_t numOriginalForOpResults = initArgs.size ();
837833 LLVM_DEBUG (DBGS () << " numOriginalForOpResults: " << numOriginalForOpResults
@@ -844,7 +840,7 @@ padThroughLoopIterArg(RewriterBase &rewriter, Value paddedValueBeforeHoisting,
844840 hoistedPackedTensor.getLoc (), hoistedPackedTensor,
845841 outerSliceOp.getMixedOffsets (), outerSliceOp.getMixedSizes (),
846842 outerSliceOp.getMixedStrides ());
847- rewriter.replaceAllUsesWith (forOp.getResult (operandNumber ), extracted);
843+ rewriter.replaceAllUsesWith (forOp.getResult (iterArgNumber ), extracted);
848844 }
849845 scf::ForOp newForOp =
850846 replaceLoopWithNewYields (rewriter, forOp, initArgs, yieldOperands);
@@ -853,20 +849,20 @@ padThroughLoopIterArg(RewriterBase &rewriter, Value paddedValueBeforeHoisting,
853849 << " \n " );
854850 LLVM_DEBUG (DBGS () << " replace source of: " << extracted << " \n " );
855851 LLVM_DEBUG (DBGS () << " with result #"
856- << numOriginalForOpResults + operandNumber
852+ << numOriginalForOpResults + iterArgNumber
857853 << " of forOp, giving us: " << extracted << " \n " );
858854 rewriter.startRootUpdate (extracted);
859855 extracted.getSourceMutable ().assign (
860- newForOp.getResult (numOriginalForOpResults + operandNumber ));
856+ newForOp.getResult (numOriginalForOpResults + iterArgNumber ));
861857 rewriter.finalizeRootUpdate (extracted);
862858
863859 LLVM_DEBUG (DBGS () << " replace uses of: " << paddedValueBeforeHoisting
864860 << " \n " );
865861 LLVM_DEBUG (DBGS () << " with region iter arg #"
866- << numOriginalForOpResults + operandNumber << " \n " );
862+ << numOriginalForOpResults + iterArgNumber << " \n " );
867863 rewriter.replaceAllUsesWith (
868864 paddedValueBeforeHoisting,
869- newForOp.getRegionIterArg (numOriginalForOpResults + operandNumber ));
865+ newForOp.getRegionIterArg (numOriginalForOpResults + iterArgNumber ));
870866
871867 return extracted;
872868}
0 commit comments