@@ -376,15 +376,18 @@ class AdjointGenerator
376376
377377#if LLVM_VERSION_MAJOR >= 10
378378 void visitLoadLike (llvm::Instruction &I, MaybeAlign alignment,
379- bool constantval, bool can_modref,
380- Value *OrigOffset = nullptr )
379+ bool constantval, Value *OrigOffset = nullptr ,
381380#else
382381 void visitLoadLike (llvm::Instruction &I, unsigned alignment, bool constantval,
383- bool can_modref, Value *OrigOffset = nullptr )
382+ Value *OrigOffset = nullptr ,
384383#endif
385- {
384+ Value *mask = nullptr , Value *orig_maskInit = nullptr ) {
386385 auto &DL = gutils->newFunc ->getParent ()->getDataLayout ();
387386
387+ assert (gutils->can_modref_map );
388+ assert (gutils->can_modref_map ->find (&I) != gutils->can_modref_map ->end ());
389+ bool can_modref = gutils->can_modref_map ->find (&I)->second ;
390+
388391 constantval |= gutils->isConstantValue (&I);
389392
390393 BasicBlock *parent = I.getParent ();
@@ -536,7 +539,8 @@ class AdjointGenerator
536539 // the instruction if the value is a potential pointer. This may not be
537540 // caught by type analysis is the result does not have a known type.
538541 if (!gutils->isConstantInstruction (&I)) {
539- bool isfloat = type->isFPOrFPVectorTy ();
542+ Type *isfloat =
543+ type->isFPOrFPVectorTy () ? type->getScalarType () : nullptr ;
540544 if (!isfloat && type->isIntOrIntVectorTy ()) {
541545 auto LoadSize = DL.getTypeSizeInBits (type) / 8 ;
542546 ConcreteType vd = BaseType::Unknown;
@@ -560,8 +564,34 @@ class AdjointGenerator
560564 getForwardBuilder (Builder2);
561565
562566 if (!gutils->isConstantValue (&I)) {
563- auto diff = Builder2.CreateLoad (
564- gutils->invertPointerM (I.getOperand (0 ), Builder2));
567+ Value *diff;
568+ if (!mask) {
569+ auto LI = Builder2.CreateLoad (
570+ gutils->invertPointerM (I.getOperand (0 ), Builder2));
571+ if (alignment)
572+ #if LLVM_VERSION_MAJOR >= 10
573+ LI->setAlignment (*alignment);
574+ #else
575+ LI->setAlignment (alignment);
576+ #endif
577+ diff = LI;
578+ } else {
579+ Type *tys[] = {I.getType (), I.getOperand (0 )->getType ()};
580+ auto F = Intrinsic::getDeclaration (gutils->oldFunc ->getParent (),
581+ Intrinsic::masked_load, tys);
582+ #if LLVM_VERSION_MAJOR >= 10
583+ Value *alignv =
584+ ConstantInt::get (Type::getInt32Ty (mask->getContext ()),
585+ alignment ? alignment->value () : 0 );
586+ #else
587+ Value *alignv = ConstantInt::get (
588+ Type::getInt32Ty (mask->getContext ()), alignment);
589+ #endif
590+ Value *args[] = {
591+ gutils->invertPointerM (I.getOperand (0 ), Builder2), alignv,
592+ mask, diffe (orig_maskInit, Builder2)};
593+ diff = Builder2.CreateCall (F, args);
594+ }
565595 setDiffe (&I, diff, Builder2);
566596 }
567597 break ;
@@ -576,8 +606,13 @@ class AdjointGenerator
576606
577607 if (!gutils->isConstantValue (I.getOperand (0 ))) {
578608 ((DiffeGradientUtils *)gutils)
579- ->addToInvertedPtrDiffe (I.getOperand (0 ), prediff, Builder2,
580- alignment, OrigOffset);
609+ ->addToInvertedPtrDiffe (
610+ I.getOperand (0 ), prediff, Builder2, alignment, OrigOffset,
611+ mask ? lookup (mask, Builder2) : nullptr );
612+ }
613+ if (mask && !gutils->isConstantValue (orig_maskInit)) {
614+ addToDiffe (orig_maskInit, prediff, Builder2, isfloat,
615+ Builder2.CreateNot (mask));
581616 }
582617 break ;
583618 }
@@ -614,10 +649,7 @@ class AdjointGenerator
614649 auto &DL = gutils->newFunc ->getParent ()->getDataLayout ();
615650
616651 bool constantval = parseTBAA (LI, DL).Inner0 ().isIntegral ();
617- assert (gutils->can_modref_map );
618- assert (gutils->can_modref_map ->find (&LI) != gutils->can_modref_map ->end ());
619- bool can_modref = gutils->can_modref_map ->find (&LI)->second ;
620- visitLoadLike (LI, alignment, constantval, can_modref);
652+ visitLoadLike (LI, alignment, constantval);
621653 eraseIfUnused (LI);
622654 }
623655
@@ -636,15 +668,9 @@ class AdjointGenerator
636668 }
637669
638670 void visitStoreInst (llvm::StoreInst &SI) {
639- Value *orig_ptr = SI.getPointerOperand ();
640- Value *orig_val = SI.getValueOperand ();
641- Value *val = gutils->getNewFromOriginal (orig_val);
642- Type *valType = orig_val->getType ();
643-
644- auto &DL = gutils->newFunc ->getParent ()->getDataLayout ();
645671 // If a store of an omp init argument, don't delete in reverse
646672 // and don't do any adjoint propagation (assumed integral)
647- for (auto U : orig_ptr ->users ()) {
673+ for (auto U : SI. getPointerOperand () ->users ()) {
648674 if (auto CI = dyn_cast<CallInst>(U)) {
649675 if (auto F = CI->getCalledFunction ()) {
650676 if (F->getName () == " __kmpc_for_static_init_4" ||
@@ -656,24 +682,47 @@ class AdjointGenerator
656682 }
657683 }
658684 }
685+ #if LLVM_VERSION_MAJOR >= 10
686+ auto align = SI.getAlign ();
687+ #else
688+ auto align = SI.getAlignment ();
689+ #endif
690+ visitCommonStore (SI, SI.getPointerOperand (), SI.getValueOperand (), align,
691+ SI.isVolatile (), SI.getOrdering (), SI.getSyncScopeID (),
692+ /* mask=*/ nullptr );
693+ eraseIfUnused (SI);
694+ }
695+
696+ #if LLVM_VERSION_MAJOR >= 10
697+ void visitCommonStore (llvm::Instruction &I, Value *orig_ptr, Value *orig_val,
698+ MaybeAlign align, bool isVolatile,
699+ AtomicOrdering ordering, SyncScope::ID syncScope,
700+ Value *mask = nullptr )
701+ #else
702+ void visitCommonStore (llvm::Instruction &I, Value *orig_ptr, Value *orig_val,
703+ unsigned align, bool isVolatile,
704+ AtomicOrdering ordering, SyncScope::ID syncScope,
705+ Value *mask = nullptr )
706+ #endif
707+ {
708+ Value *val = gutils->getNewFromOriginal (orig_val);
709+ Type *valType = orig_val->getType ();
710+
711+ auto &DL = gutils->newFunc ->getParent ()->getDataLayout ();
659712
660- if (unnecessaryStores.count (&SI)) {
661- eraseIfUnused (SI);
713+ if (unnecessaryStores.count (&I)) {
662714 return ;
663715 }
664716
665717 if (gutils->isConstantValue (orig_ptr)) {
666- eraseIfUnused (SI);
667718 return ;
668719 }
669720
670721 bool constantval = gutils->isConstantValue (orig_val) ||
671- parseTBAA (SI , DL).Inner0 ().isIntegral ();
722+ parseTBAA (I , DL).Inner0 ().isIntegral ();
672723
673724 // TODO allow recognition of other types that could contain pointers [e.g.
674725 // {void*, void*} or <2 x i64> ]
675- StoreInst *ts = nullptr ;
676-
677726 auto storeSize = DL.getTypeSizeInBits (valType) / 8 ;
678727
679728 // ! Storing a floating point value
@@ -688,12 +737,12 @@ class AdjointGenerator
688737 FT = fp.isFloat ();
689738 } else if (isa<ConstantInt>(orig_val) ||
690739 valType->isIntOrIntVectorTy ()) {
691- llvm::errs () << " assuming type as integral for store: " << SI << " \n " ;
740+ llvm::errs () << " assuming type as integral for store: " << I << " \n " ;
692741 FT = nullptr ;
693742 } else {
694743 TR.firstPointer (storeSize, orig_ptr, /* errifnotfound*/ true ,
695744 /* pointerIntSame*/ true );
696- llvm::errs () << " cannot deduce type of store " << SI << " \n " ;
745+ llvm::errs () << " cannot deduce type of store " << I << " \n " ;
697746 assert (0 && " cannot deduce" );
698747 }
699748 } else {
@@ -710,35 +759,61 @@ class AdjointGenerator
710759 break ;
711760 case DerivativeMode::ReverseModeGradient:
712761 case DerivativeMode::ReverseModeCombined: {
713- IRBuilder<> Builder2 (SI .getParent ());
762+ IRBuilder<> Builder2 (I .getParent ());
714763 getReverseBuilder (Builder2);
715764
716765 if (constantval) {
717- ts = setPtrDiffe (orig_ptr, Constant::getNullValue (valType), Builder2);
766+ gutils->setPtrDiffe (orig_ptr, Constant::getNullValue (valType),
767+ Builder2, align, isVolatile, ordering, syncScope,
768+ mask);
718769 } else {
719- auto dif1 = Builder2.CreateLoad (
720- lookup (gutils->invertPointerM (orig_ptr, Builder2), Builder2));
770+ Value *diff;
771+ if (!mask) {
772+ auto dif1 = Builder2.CreateLoad (
773+ lookup (gutils->invertPointerM (orig_ptr, Builder2), Builder2),
774+ isVolatile);
775+ if (align)
776+ #if LLVM_VERSION_MAJOR >= 10
777+ dif1->setAlignment (*align);
778+ #else
779+ dif1->setAlignment (align);
780+ #endif
781+ dif1->setOrdering (ordering);
782+ dif1->setSyncScopeID (syncScope);
783+ diff = dif1;
784+ } else {
785+ mask = lookup (mask, Builder2);
786+ Type *tys[] = {valType, orig_ptr->getType ()};
787+ auto F = Intrinsic::getDeclaration (gutils->oldFunc ->getParent (),
788+ Intrinsic::masked_load, tys);
721789#if LLVM_VERSION_MAJOR >= 10
722- dif1->setAlignment (SI.getAlign ());
790+ Value *alignv =
791+ ConstantInt::get (Type::getInt32Ty (mask->getContext ()),
792+ align ? align->value () : 0 );
723793#else
724- dif1->setAlignment (SI.getAlignment ());
794+ Value *alignv =
795+ ConstantInt::get (Type::getInt32Ty (mask->getContext ()), align);
725796#endif
726- ts = setPtrDiffe (orig_ptr, Constant::getNullValue (valType), Builder2);
727- addToDiffe (orig_val, dif1, Builder2, FT);
797+ Value *args[] = {
798+ lookup (gutils->invertPointerM (orig_ptr, Builder2), Builder2),
799+ alignv, mask, Constant::getNullValue (valType)};
800+ diff = Builder2.CreateCall (F, args);
801+ }
802+ gutils->setPtrDiffe (orig_ptr, Constant::getNullValue (valType),
803+ Builder2, align, isVolatile, ordering, syncScope,
804+ mask);
805+ addToDiffe (orig_val, diff, Builder2, FT, mask);
728806 }
729807 break ;
730808 }
731809 case DerivativeMode::ForwardMode: {
732- IRBuilder<> Builder2 (&SI );
810+ IRBuilder<> Builder2 (&I );
733811 getForwardBuilder (Builder2);
734812
735- if (constantval) {
736- ts = setPtrDiffe (orig_ptr, Constant::getNullValue (valType), Builder2);
737- } else {
738- auto diff = diffe (orig_val, Builder2);
739-
740- ts = setPtrDiffe (orig_ptr, diff, Builder2);
741- }
813+ Value *diff = constantval ? Constant::getNullValue (valType)
814+ : diffe (orig_val, Builder2);
815+ gutils->setPtrDiffe (orig_ptr, diff, Builder2, align, isVolatile,
816+ ordering, syncScope, mask);
742817 break ;
743818 }
744819 }
@@ -749,7 +824,7 @@ class AdjointGenerator
749824 if (Mode == DerivativeMode::ReverseModePrimal ||
750825 Mode == DerivativeMode::ReverseModeCombined ||
751826 Mode == DerivativeMode::ForwardMode) {
752- IRBuilder<> storeBuilder (gutils->getNewFromOriginal (&SI ));
827+ IRBuilder<> storeBuilder (gutils->getNewFromOriginal (&I ));
753828
754829 Value *valueop = nullptr ;
755830
@@ -758,21 +833,10 @@ class AdjointGenerator
758833 } else {
759834 valueop = gutils->invertPointerM (orig_val, storeBuilder);
760835 }
761- ts = setPtrDiffe (orig_ptr, valueop, storeBuilder);
836+ gutils->setPtrDiffe (orig_ptr, valueop, storeBuilder, align, isVolatile,
837+ ordering, syncScope, mask);
762838 }
763839 }
764-
765- if (ts) {
766- #if LLVM_VERSION_MAJOR >= 10
767- ts->setAlignment (SI.getAlign ());
768- #else
769- ts->setAlignment (SI.getAlignment ());
770- #endif
771- ts->setVolatile (SI.isVolatile ());
772- ts->setOrdering (SI.getOrdering ());
773- ts->setSyncScopeID (SI.getSyncScopeID ());
774- }
775- eraseIfUnused (SI);
776840 }
777841
778842 void visitGetElementPtrInst (llvm::GetElementPtrInst &gep) {
@@ -1366,13 +1430,11 @@ class AdjointGenerator
13661430 ((DiffeGradientUtils *)gutils)->setDiffe (val, dif, Builder);
13671431 }
13681432
1369- StoreInst *setPtrDiffe (Value *val, Value *dif, IRBuilder<> &Builder) {
1370- return gutils->setPtrDiffe (val, dif, Builder);
1371- }
1372-
13731433 std::vector<SelectInst *> addToDiffe (Value *val, Value *dif,
1374- IRBuilder<> &Builder, Type *T) {
1375- return ((DiffeGradientUtils *)gutils)->addToDiffe (val, dif, Builder, T);
1434+ IRBuilder<> &Builder, Type *T,
1435+ Value *mask = nullptr ) {
1436+ return ((DiffeGradientUtils *)gutils)
1437+ ->addToDiffe (val, dif, Builder, T, /* idxs*/ {}, mask);
13761438 }
13771439
13781440 Value *lookup (Value *val, IRBuilder<> &Builder) {
@@ -2351,17 +2413,45 @@ class AdjointGenerator
23512413 auto CI = cast<ConstantInt>(I.getOperand (1 ));
23522414#if LLVM_VERSION_MAJOR >= 10
23532415 visitLoadLike (I, /* Align*/ MaybeAlign (CI->getZExtValue ()),
2354- /* constantval*/ false ,
2355- /* can_modref*/ false );
2416+ /* constantval*/ false );
23562417#else
2357- visitLoadLike (I, /* Align*/ CI->getZExtValue (), /* constantval*/ false ,
2358- /* can_modref*/ false );
2418+ visitLoadLike (I, /* Align*/ CI->getZExtValue (), /* constantval*/ false );
23592419#endif
23602420 return ;
23612421 }
23622422 default :
23632423 break ;
23642424 }
2425+
2426+ if (ID == Intrinsic::masked_store) {
2427+ auto align0 = cast<ConstantInt>(I.getOperand (2 ))->getZExtValue ();
2428+ #if LLVM_VERSION_MAJOR >= 10
2429+ auto align = MaybeAlign (align0);
2430+ #else
2431+ auto align = align0;
2432+ #endif
2433+ visitCommonStore (I, /* orig_ptr*/ I.getOperand (1 ),
2434+ /* orig_val*/ I.getOperand (0 ), align,
2435+ /* isVolatile*/ false , llvm::AtomicOrdering::NotAtomic,
2436+ SyncScope::SingleThread,
2437+ /* mask*/ gutils->getNewFromOriginal (I.getOperand (3 )));
2438+ return ;
2439+ }
2440+ if (ID == Intrinsic::masked_load) {
2441+ auto align0 = cast<ConstantInt>(I.getOperand (1 ))->getZExtValue ();
2442+ #if LLVM_VERSION_MAJOR >= 10
2443+ auto align = MaybeAlign (align0);
2444+ #else
2445+ auto align = align0;
2446+ #endif
2447+ auto &DL = gutils->newFunc ->getParent ()->getDataLayout ();
2448+ bool constantval = parseTBAA (I, DL).Inner0 ().isIntegral ();
2449+ visitLoadLike (I, align, constantval, /* OrigOffset*/ nullptr ,
2450+ /* mask*/ gutils->getNewFromOriginal (I.getOperand (2 )),
2451+ /* orig_maskInit*/ I.getOperand (3 ));
2452+ return ;
2453+ }
2454+
23652455 switch (Mode) {
23662456 case DerivativeMode::ReverseModePrimal: {
23672457 switch (ID) {
0 commit comments