//===- ScalarizeMaskedMemIntrin.cpp - Scalarize unsupported masked mem ----===//
//                                    intrinsics
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass replaces masked memory intrinsics - when unsupported by the target
// - with a chain of basic blocks that deal with the elements one-by-one if the
// appropriate mask bit is set.
//
//===----------------------------------------------------------------------===//
#include "llvm/ADT/Twine.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/CodeGen/TargetSubtargetInfo.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constant.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/Value.h"
#include "llvm/Pass.h"
#include "llvm/Support/Casting.h"
#include <algorithm>
#include <cassert>

using namespace llvm;

#define DEBUG_TYPE "scalarize-masked-mem-intrin"
namespace {

class ScalarizeMaskedMemIntrin : public FunctionPass {
  const TargetTransformInfo *TTI = nullptr;

public:
  static char ID; // Pass identification, replacement for typeid

  explicit ScalarizeMaskedMemIntrin() : FunctionPass(ID) {
    initializeScalarizeMaskedMemIntrinPass(*PassRegistry::getPassRegistry());
  }

  bool runOnFunction(Function &F) override;

  StringRef getPassName() const override {
    return "Scalarize Masked Memory Intrinsics";
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<TargetTransformInfoWrapperPass>();
  }

private:
  bool optimizeBlock(BasicBlock &BB, bool &ModifiedDT);
  bool optimizeCallInst(CallInst *CI, bool &ModifiedDT);
};

} // end anonymous namespace

char ScalarizeMaskedMemIntrin::ID = 0;

INITIALIZE_PASS(ScalarizeMaskedMemIntrin, DEBUG_TYPE,
                "Scalarize unsupported masked memory intrinsics", false, false)

FunctionPass *llvm::createScalarizeMaskedMemIntrinPass() {
  return new ScalarizeMaskedMemIntrin();
}
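
// Returns true if every lane of the mask is a compile-time constant, i.e. the
// mask is a vector whose elements are all ConstantInt. In that case the
// callers below can emit straight-line code for the enabled lanes instead of
// per-lane branches.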
static bool isConstantIntVector(Value *Mask) {
  Constant *C = dyn_cast<Constant>(Mask);
  if (!C)
    return false;

  unsigned NumElts = Mask->getType()->getVectorNumElements();
  for (unsigned i = 0; i != NumElts; ++i) {
    Constant *CElt = C->getAggregateElement(i);
    if (!CElt || !isa<ConstantInt>(CElt))
      return false;
  }

  return true;
}
// Translate a masked load intrinsic like
// <16 x i32> @llvm.masked.load(<16 x i32>* %addr, i32 align,
//                              <16 x i1> %mask, <16 x i32> %passthru)
// to a chain of basic blocks that load the elements one-by-one if
// the appropriate mask bit is set
//
//  %1 = bitcast i8* %addr to i32*
//  %2 = extractelement <16 x i1> %mask, i32 0
//  br i1 %2, label %cond.load, label %else
//
// cond.load:                                        ; preds = %0
//  %3 = getelementptr i32* %1, i32 0
//  %4 = load i32* %3
//  %5 = insertelement <16 x i32> %passthru, i32 %4, i32 0
//  br label %else
//
// else:                                             ; preds = %0, %cond.load
//  %res.phi.else = phi <16 x i32> [ %5, %cond.load ], [ undef, %0 ]
//  %6 = extractelement <16 x i1> %mask, i32 1
//  br i1 %6, label %cond.load1, label %else2
//
// cond.load1:                                       ; preds = %else
//  %7 = getelementptr i32* %1, i32 1
//  %8 = load i32* %7
//  %9 = insertelement <16 x i32> %res.phi.else, i32 %8, i32 1
//  br label %else2
//
// else2:                                            ; preds = %else, %cond.load1
//  %res.phi.else3 = phi <16 x i32> [ %9, %cond.load1 ], [ %res.phi.else, %else ]
//  %10 = extractelement <16 x i1> %mask, i32 2
//  br i1 %10, label %cond.load4, label %else5
//
static void scalarizeMaskedLoad(CallInst *CI, bool &ModifiedDT) {
  Value *Ptr = CI->getArgOperand(0);
  Value *Alignment = CI->getArgOperand(1);
  Value *Mask = CI->getArgOperand(2);
  Value *Src0 = CI->getArgOperand(3);

  unsigned AlignVal = cast<ConstantInt>(Alignment)->getZExtValue();
  VectorType *VecType = cast<VectorType>(CI->getType());
  Type *EltTy = VecType->getElementType();

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();

  Builder.SetInsertPoint(InsertPt);
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  // Short-cut if the mask is all-true.
  if (isa<Constant>(Mask) && cast<Constant>(Mask)->isAllOnesValue()) {
    Value *NewI = Builder.CreateAlignedLoad(VecType, Ptr, AlignVal);
    CI->replaceAllUsesWith(NewI);
    CI->eraseFromParent();
    return;
  }

  // Adjust alignment for the scalar instruction.
  AlignVal = MinAlign(AlignVal, EltTy->getPrimitiveSizeInBits() / 8);
  // Bitcast %addr from i8* to EltTy*
  Type *NewPtrType =
      EltTy->getPointerTo(Ptr->getType()->getPointerAddressSpace());
  Value *FirstEltPtr = Builder.CreateBitCast(Ptr, NewPtrType);
  unsigned VectorWidth = VecType->getNumElements();

  // The result vector
  Value *VResult = Src0;

  if (isConstantIntVector(Mask)) {
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
        continue;
      Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx);
      LoadInst *Load = Builder.CreateAlignedLoad(EltTy, Gep, AlignVal);
      VResult = Builder.CreateInsertElement(VResult, Load, Idx);
    }
    CI->replaceAllUsesWith(VResult);
    CI->eraseFromParent();
    return;
  }

  // If the mask is not v1i1, use scalar bit test operations. This generates
  // better results on X86 at least.
  Value *SclrMask;
  if (VectorWidth != 1) {
    Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
    SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    //  %res.phi.else3 = phi <16 x i32> [ %11, %cond.load1 ], [ %res.phi.else, %else ]
    //  %mask_1 = and i16 %scalar_mask, i32 1 << Idx
    //  %cond = icmp ne i16 %mask_1, 0
    //  br i1 %mask_1, label %cond.load, label %else
    //
    Value *Predicate;
    if (VectorWidth != 1) {
      Value *Mask = Builder.getInt(APInt::getOneBitSet(VectorWidth, Idx));
      Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
                                       Builder.getIntN(VectorWidth, 0));
    } else {
      Predicate = Builder.CreateExtractElement(Mask, Idx);
    }

    // Create "cond" block
    //
    //  %EltAddr = getelementptr i32* %1, i32 0
    //  %Elt = load i32* %EltAddr
    //  VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
    //
    BasicBlock *CondBlock = IfBlock->splitBasicBlock(InsertPt->getIterator(),
                                                     "cond.load");
    Builder.SetInsertPoint(InsertPt);
    Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx);
    LoadInst *Load = Builder.CreateAlignedLoad(EltTy, Gep, AlignVal);
    Value *NewVResult = Builder.CreateInsertElement(VResult, Load, Idx);

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock =
        CondBlock->splitBasicBlock(InsertPt->getIterator(), "else");
    Builder.SetInsertPoint(InsertPt);
    Instruction *OldBr = IfBlock->getTerminator();
    BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
    OldBr->eraseFromParent();
    BasicBlock *PrevIfBlock = IfBlock;
    IfBlock = NewIfBlock;

    // Create the phi to join the new and previous value.
    PHINode *Phi = Builder.CreatePHI(VecType, 2, "res.phi.else");
    Phi->addIncoming(NewVResult, CondBlock);
    Phi->addIncoming(VResult, PrevIfBlock);
    VResult = Phi;
  }

  CI->replaceAllUsesWith(VResult);
  CI->eraseFromParent();

  ModifiedDT = true;
}
// Translate a masked store intrinsic, like
// void @llvm.masked.store(<16 x i32> %src, <16 x i32>* %addr, i32 align,
//                         <16 x i1> %mask)
// to a chain of basic blocks that store the elements one-by-one if
// the appropriate mask bit is set
//
//  %1 = bitcast i8* %addr to i32*
//  %2 = extractelement <16 x i1> %mask, i32 0
//  br i1 %2, label %cond.store, label %else
//
// cond.store:                                       ; preds = %0
//  %3 = extractelement <16 x i32> %val, i32 0
//  %4 = getelementptr i32* %1, i32 0
//  store i32 %3, i32* %4
//  br label %else
//
// else:                                             ; preds = %0, %cond.store
//  %5 = extractelement <16 x i1> %mask, i32 1
//  br i1 %5, label %cond.store1, label %else2
//
// cond.store1:                                      ; preds = %else
//  %6 = extractelement <16 x i32> %val, i32 1
//  %7 = getelementptr i32* %1, i32 1
//  store i32 %6, i32* %7
//  br label %else2
//  . . .
static void scalarizeMaskedStore(CallInst *CI, bool &ModifiedDT) {
  Value *Src = CI->getArgOperand(0);
  Value *Ptr = CI->getArgOperand(1);
  Value *Alignment = CI->getArgOperand(2);
  Value *Mask = CI->getArgOperand(3);

  unsigned AlignVal = cast<ConstantInt>(Alignment)->getZExtValue();
  VectorType *VecType = cast<VectorType>(Src->getType());
  Type *EltTy = VecType->getElementType();

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();
  Builder.SetInsertPoint(InsertPt);
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  // Short-cut if the mask is all-true.
  if (isa<Constant>(Mask) && cast<Constant>(Mask)->isAllOnesValue()) {
    Builder.CreateAlignedStore(Src, Ptr, AlignVal);
    CI->eraseFromParent();
    return;
  }

  // Adjust alignment for the scalar instruction.
  AlignVal = MinAlign(AlignVal, EltTy->getPrimitiveSizeInBits() / 8);
  // Bitcast %addr from i8* to EltTy*
  Type *NewPtrType =
      EltTy->getPointerTo(Ptr->getType()->getPointerAddressSpace());
  Value *FirstEltPtr = Builder.CreateBitCast(Ptr, NewPtrType);
  unsigned VectorWidth = VecType->getNumElements();

  if (isConstantIntVector(Mask)) {
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
        continue;
      Value *OneElt = Builder.CreateExtractElement(Src, Idx);
      Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx);
      Builder.CreateAlignedStore(OneElt, Gep, AlignVal);
    }
    CI->eraseFromParent();
    return;
  }

  // If the mask is not v1i1, use scalar bit test operations. This generates
  // better results on X86 at least.
  Value *SclrMask;
  if (VectorWidth != 1) {
    Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
    SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    //  %mask_1 = and i16 %scalar_mask, i32 1 << Idx
    //  %cond = icmp ne i16 %mask_1, 0
    //  br i1 %mask_1, label %cond.store, label %else
    //
    Value *Predicate;
    if (VectorWidth != 1) {
      Value *Mask = Builder.getInt(APInt::getOneBitSet(VectorWidth, Idx));
      Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
                                       Builder.getIntN(VectorWidth, 0));
    } else {
      Predicate = Builder.CreateExtractElement(Mask, Idx);
    }

    // Create "cond" block
    //
    //  %OneElt = extractelement <16 x i32> %Src, i32 Idx
    //  %EltAddr = getelementptr i32* %1, i32 0
    //  %store i32 %OneElt, i32* %EltAddr
    //
    BasicBlock *CondBlock =
        IfBlock->splitBasicBlock(InsertPt->getIterator(), "cond.store");
    Builder.SetInsertPoint(InsertPt);
    Value *OneElt = Builder.CreateExtractElement(Src, Idx);
    Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx);
    Builder.CreateAlignedStore(OneElt, Gep, AlignVal);

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock =
        CondBlock->splitBasicBlock(InsertPt->getIterator(), "else");
    Builder.SetInsertPoint(InsertPt);
    Instruction *OldBr = IfBlock->getTerminator();
    BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
    OldBr->eraseFromParent();
    IfBlock = NewIfBlock;
  }
  CI->eraseFromParent();

  ModifiedDT = true;
}
// Translate a masked gather intrinsic like
// <16 x i32> @llvm.masked.gather.v16i32(<16 x i32*> %Ptrs, i32 4,
//                                       <16 x i1> %Mask, <16 x i32> %Src)
// to a chain of basic blocks that load the elements one-by-one if
// the appropriate mask bit is set
//
//  %Ptrs = getelementptr i32, i32* %base, <16 x i64> %ind
//  %Mask0 = extractelement <16 x i1> %Mask, i32 0
//  br i1 %Mask0, label %cond.load, label %else
//
// cond.load:
//  %Ptr0 = extractelement <16 x i32*> %Ptrs, i32 0
//  %Load0 = load i32, i32* %Ptr0, align 4
//  %Res0 = insertelement <16 x i32> undef, i32 %Load0, i32 0
//  br label %else
//
// else:
//  %res.phi.else = phi <16 x i32> [ %Res0, %cond.load ], [ undef, %0 ]
//  %Mask1 = extractelement <16 x i1> %Mask, i32 1
//  br i1 %Mask1, label %cond.load1, label %else2
//
// cond.load1:
//  %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
//  %Load1 = load i32, i32* %Ptr1, align 4
//  %Res1 = insertelement <16 x i32> %res.phi.else, i32 %Load1, i32 1
//  br label %else2
//  . . .
//  %Result = select <16 x i1> %Mask, <16 x i32> %res.phi.select, <16 x i32> %Src
//  ret <16 x i32> %Result
static void scalarizeMaskedGather(CallInst *CI, bool &ModifiedDT) {
  Value *Ptrs = CI->getArgOperand(0);
  Value *Alignment = CI->getArgOperand(1);
  Value *Mask = CI->getArgOperand(2);
  Value *Src0 = CI->getArgOperand(3);

  VectorType *VecType = cast<VectorType>(CI->getType());
  Type *EltTy = VecType->getElementType();

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();
  Builder.SetInsertPoint(InsertPt);
  unsigned AlignVal = cast<ConstantInt>(Alignment)->getZExtValue();

  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  // The result vector
  Value *VResult = Src0;
  unsigned VectorWidth = VecType->getNumElements();

  // Shorten the way if the mask is a vector of constants.
  if (isConstantIntVector(Mask)) {
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
        continue;
      Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
      LoadInst *Load =
          Builder.CreateAlignedLoad(EltTy, Ptr, AlignVal, "Load" + Twine(Idx));
      VResult =
          Builder.CreateInsertElement(VResult, Load, Idx, "Res" + Twine(Idx));
    }
    CI->replaceAllUsesWith(VResult);
    CI->eraseFromParent();
    return;
  }

  // If the mask is not v1i1, use scalar bit test operations. This generates
  // better results on X86 at least.
  Value *SclrMask;
  if (VectorWidth != 1) {
    Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
    SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    //  %Mask1 = and i16 %scalar_mask, i32 1 << Idx
    //  %cond = icmp ne i16 %mask_1, 0
    //  br i1 %Mask1, label %cond.load, label %else
    //
    Value *Predicate;
    if (VectorWidth != 1) {
      Value *Mask = Builder.getInt(APInt::getOneBitSet(VectorWidth, Idx));
      Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
                                       Builder.getIntN(VectorWidth, 0));
    } else {
      Predicate = Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));
    }

    // Create "cond" block
    //
    //  %EltAddr = getelementptr i32* %1, i32 0
    //  %Elt = load i32* %EltAddr
    //  VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
    //
    BasicBlock *CondBlock = IfBlock->splitBasicBlock(InsertPt, "cond.load");
    Builder.SetInsertPoint(InsertPt);
    Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
    LoadInst *Load =
        Builder.CreateAlignedLoad(EltTy, Ptr, AlignVal, "Load" + Twine(Idx));
    Value *NewVResult =
        Builder.CreateInsertElement(VResult, Load, Idx, "Res" + Twine(Idx));

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock = CondBlock->splitBasicBlock(InsertPt, "else");
    Builder.SetInsertPoint(InsertPt);
    Instruction *OldBr = IfBlock->getTerminator();
    BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
    OldBr->eraseFromParent();
    BasicBlock *PrevIfBlock = IfBlock;
    IfBlock = NewIfBlock;

    PHINode *Phi = Builder.CreatePHI(VecType, 2, "res.phi.else");
    Phi->addIncoming(NewVResult, CondBlock);
    Phi->addIncoming(VResult, PrevIfBlock);
    VResult = Phi;
  }

  CI->replaceAllUsesWith(VResult);
  CI->eraseFromParent();

  ModifiedDT = true;
}
// Translate a masked scatter intrinsic, like
// void @llvm.masked.scatter.v16i32(<16 x i32> %Src, <16 x i32*> %Ptrs, i32 4,
//                                  <16 x i1> %Mask)
// to a chain of basic blocks that store the elements one-by-one if
// the appropriate mask bit is set.
//
//  %Ptrs = getelementptr i32, i32* %ptr, <16 x i64> %ind
//  %Mask0 = extractelement <16 x i1> %Mask, i32 0
//  br i1 %Mask0, label %cond.store, label %else
//
// cond.store:
//  %Elt0 = extractelement <16 x i32> %Src, i32 0
//  %Ptr0 = extractelement <16 x i32*> %Ptrs, i32 0
//  store i32 %Elt0, i32* %Ptr0, align 4
//  br label %else
//
// else:
//  %Mask1 = extractelement <16 x i1> %Mask, i32 1
//  br i1 %Mask1, label %cond.store1, label %else2
//
// cond.store1:
//  %Elt1 = extractelement <16 x i32> %Src, i32 1
//  %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
//  store i32 %Elt1, i32* %Ptr1, align 4
//  br label %else2
//  . . .
static void scalarizeMaskedScatter(CallInst *CI, bool &ModifiedDT) {
  Value *Src = CI->getArgOperand(0);
  Value *Ptrs = CI->getArgOperand(1);
  Value *Alignment = CI->getArgOperand(2);
  Value *Mask = CI->getArgOperand(3);

  assert(isa<VectorType>(Src->getType()) &&
         "Unexpected data type in masked scatter intrinsic");
  assert(isa<VectorType>(Ptrs->getType()) &&
         isa<PointerType>(Ptrs->getType()->getVectorElementType()) &&
         "Vector of pointers is expected in masked scatter intrinsic");

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();
  Builder.SetInsertPoint(InsertPt);
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  unsigned AlignVal = cast<ConstantInt>(Alignment)->getZExtValue();
  unsigned VectorWidth = Src->getType()->getVectorNumElements();

  // Shorten the way if the mask is a vector of constants.
  if (isConstantIntVector(Mask)) {
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
        continue;
      Value *OneElt =
          Builder.CreateExtractElement(Src, Idx, "Elt" + Twine(Idx));
      Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
      Builder.CreateAlignedStore(OneElt, Ptr, AlignVal);
    }
    CI->eraseFromParent();
    return;
  }

  // If the mask is not v1i1, use scalar bit test operations. This generates
  // better results on X86 at least.
  Value *SclrMask;
  if (VectorWidth != 1) {
    Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
    SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    //  %Mask1 = and i16 %scalar_mask, i32 1 << Idx
    //  %cond = icmp ne i16 %mask_1, 0
    //  br i1 %Mask1, label %cond.store, label %else
    //
    Value *Predicate;
    if (VectorWidth != 1) {
      Value *Mask = Builder.getInt(APInt::getOneBitSet(VectorWidth, Idx));
      Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
                                       Builder.getIntN(VectorWidth, 0));
    } else {
      Predicate = Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));
    }

    // Create "cond" block
    //
    //  %Elt1 = extractelement <16 x i32> %Src, i32 1
    //  %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
    //  %store i32 %Elt1, i32* %Ptr1
    //
    BasicBlock *CondBlock = IfBlock->splitBasicBlock(InsertPt, "cond.store");
    Builder.SetInsertPoint(InsertPt);
    Value *OneElt = Builder.CreateExtractElement(Src, Idx, "Elt" + Twine(Idx));
    Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
    Builder.CreateAlignedStore(OneElt, Ptr, AlignVal);

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock = CondBlock->splitBasicBlock(InsertPt, "else");
    Builder.SetInsertPoint(InsertPt);
    Instruction *OldBr = IfBlock->getTerminator();
    BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
    OldBr->eraseFromParent();
    IfBlock = NewIfBlock;
  }
  CI->eraseFromParent();

  ModifiedDT = true;
}
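
// Translate a masked expandload intrinsic, like
// <16 x i32> @llvm.masked.expandload.v16i32(i32* %ptr, <16 x i1> %Mask,
//                                           <16 x i32> %PassThru)
// to a chain of basic blocks that load from consecutive memory locations for
// the enabled lanes only, advancing the pointer by one element after each
// performed load. Illustrative shape of the emitted IR for the first lane
// (simplified; the code below uses a scalar bit test on the mask when it is
// wider than v1i1):
//
//  %Mask0 = extractelement <16 x i1> %Mask, i32 0
//  br i1 %Mask0, label %cond.load, label %else
//
// cond.load:
//  %Load0 = load i32, i32* %ptr, align 1
//  %Res0 = insertelement <16 x i32> %PassThru, i32 %Load0, i32 0
//  %NextPtr = getelementptr i32, i32* %ptr, i32 1
//  br label %else
//
// else:
//  %res.phi.else = phi <16 x i32> [ %Res0, %cond.load ], [ %PassThru, %0 ]
//  %ptr.phi.else = phi i32* [ %NextPtr, %cond.load ], [ %ptr, %0 ]
//  %Mask1 = extractelement <16 x i1> %Mask, i32 1
//  br i1 %Mask1, label %cond.load1, label %else2
//  . . .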
static void scalarizeMaskedExpandLoad(CallInst *CI, bool &ModifiedDT) {
  Value *Ptr = CI->getArgOperand(0);
  Value *Mask = CI->getArgOperand(1);
  Value *PassThru = CI->getArgOperand(2);

  VectorType *VecType = cast<VectorType>(CI->getType());
  Type *EltTy = VecType->getElementType();

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();

  Builder.SetInsertPoint(InsertPt);
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  unsigned VectorWidth = VecType->getNumElements();

  // The result vector
  Value *VResult = PassThru;

  // Shorten the way if the mask is a vector of constants.
  if (isConstantIntVector(Mask)) {
    unsigned MemIndex = 0;
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
        continue;
      Value *NewPtr = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, MemIndex);
      LoadInst *Load =
          Builder.CreateAlignedLoad(EltTy, NewPtr, 1, "Load" + Twine(Idx));
      VResult =
          Builder.CreateInsertElement(VResult, Load, Idx, "Res" + Twine(Idx));
      ++MemIndex;
    }
    CI->replaceAllUsesWith(VResult);
    CI->eraseFromParent();
    return;
  }

  // If the mask is not v1i1, use scalar bit test operations. This generates
  // better results on X86 at least.
  Value *SclrMask;
  if (VectorWidth != 1) {
    Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
    SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    //  %res.phi.else3 = phi <16 x i32> [ %11, %cond.load1 ], [ %res.phi.else, %else ]
    //  %mask_1 = extractelement <16 x i1> %mask, i32 Idx
    //  br i1 %mask_1, label %cond.load, label %else
    //
    Value *Predicate;
    if (VectorWidth != 1) {
      Value *Mask = Builder.getInt(APInt::getOneBitSet(VectorWidth, Idx));
      Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
                                       Builder.getIntN(VectorWidth, 0));
    } else {
      Predicate = Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));
    }

    // Create "cond" block
    //
    //  %EltAddr = getelementptr i32* %1, i32 0
    //  %Elt = load i32* %EltAddr
    //  VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
    //
    BasicBlock *CondBlock = IfBlock->splitBasicBlock(InsertPt->getIterator(),
                                                     "cond.load");
    Builder.SetInsertPoint(InsertPt);
    LoadInst *Load = Builder.CreateAlignedLoad(EltTy, Ptr, 1);
    Value *NewVResult = Builder.CreateInsertElement(VResult, Load, Idx);

    // Move the pointer if there are more blocks to come.
    Value *NewPtr;
    if ((Idx + 1) != VectorWidth)
      NewPtr = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, 1);

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock =
        CondBlock->splitBasicBlock(InsertPt->getIterator(), "else");
    Builder.SetInsertPoint(InsertPt);
    Instruction *OldBr = IfBlock->getTerminator();
    BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
    OldBr->eraseFromParent();
    BasicBlock *PrevIfBlock = IfBlock;
    IfBlock = NewIfBlock;

    // Create the phi to join the new and previous value.
    PHINode *ResultPhi = Builder.CreatePHI(VecType, 2, "res.phi.else");
    ResultPhi->addIncoming(NewVResult, CondBlock);
    ResultPhi->addIncoming(VResult, PrevIfBlock);
    VResult = ResultPhi;

    // Add a PHI for the pointer if this isn't the last iteration.
    if ((Idx + 1) != VectorWidth) {
      PHINode *PtrPhi = Builder.CreatePHI(Ptr->getType(), 2, "ptr.phi.else");
      PtrPhi->addIncoming(NewPtr, CondBlock);
      PtrPhi->addIncoming(Ptr, PrevIfBlock);
      Ptr = PtrPhi;
    }
  }

  CI->replaceAllUsesWith(VResult);
  CI->eraseFromParent();

  ModifiedDT = true;
}
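
// Translate a masked compressstore intrinsic, like
// void @llvm.masked.compressstore.v16i32(<16 x i32> %Src, i32* %ptr,
//                                        <16 x i1> %Mask)
// to a chain of basic blocks that store the enabled lanes to consecutive
// memory locations, advancing the pointer by one element after each performed
// store. Illustrative shape of the emitted IR for the first lane (simplified;
// the code below uses a scalar bit test on the mask when it is wider than
// v1i1):
//
//  %Mask0 = extractelement <16 x i1> %Mask, i32 0
//  br i1 %Mask0, label %cond.store, label %else
//
// cond.store:
//  %Elt0 = extractelement <16 x i32> %Src, i32 0
//  store i32 %Elt0, i32* %ptr, align 1
//  %NextPtr = getelementptr i32, i32* %ptr, i32 1
//  br label %else
//
// else:
//  %ptr.phi.else = phi i32* [ %NextPtr, %cond.store ], [ %ptr, %0 ]
//  %Mask1 = extractelement <16 x i1> %Mask, i32 1
//  br i1 %Mask1, label %cond.store1, label %else2
//  . . .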
static void scalarizeMaskedCompressStore(CallInst *CI, bool &ModifiedDT) {
  Value *Src = CI->getArgOperand(0);
  Value *Ptr = CI->getArgOperand(1);
  Value *Mask = CI->getArgOperand(2);

  VectorType *VecType = cast<VectorType>(Src->getType());

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();

  Builder.SetInsertPoint(InsertPt);
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  Type *EltTy = VecType->getVectorElementType();

  unsigned VectorWidth = VecType->getNumElements();

  // Shorten the way if the mask is a vector of constants.
  if (isConstantIntVector(Mask)) {
    unsigned MemIndex = 0;
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
        continue;
      Value *OneElt =
          Builder.CreateExtractElement(Src, Idx, "Elt" + Twine(Idx));
      Value *NewPtr = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, MemIndex);
      Builder.CreateAlignedStore(OneElt, NewPtr, 1);
      ++MemIndex;
    }
    CI->eraseFromParent();
    return;
  }

  // If the mask is not v1i1, use scalar bit test operations. This generates
  // better results on X86 at least.
  Value *SclrMask;
  if (VectorWidth != 1) {
    Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
    SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    //  %mask_1 = extractelement <16 x i1> %mask, i32 Idx
    //  br i1 %mask_1, label %cond.store, label %else
    //
    Value *Predicate;
    if (VectorWidth != 1) {
      Value *Mask = Builder.getInt(APInt::getOneBitSet(VectorWidth, Idx));
      Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
                                       Builder.getIntN(VectorWidth, 0));
    } else {
      Predicate = Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));
    }

    // Create "cond" block
    //
    //  %OneElt = extractelement <16 x i32> %Src, i32 Idx
    //  %EltAddr = getelementptr i32* %1, i32 0
    //  %store i32 %OneElt, i32* %EltAddr
    //
    BasicBlock *CondBlock =
        IfBlock->splitBasicBlock(InsertPt->getIterator(), "cond.store");
    Builder.SetInsertPoint(InsertPt);
    Value *OneElt = Builder.CreateExtractElement(Src, Idx);
    Builder.CreateAlignedStore(OneElt, Ptr, 1);

    // Move the pointer if there are more blocks to come.
    Value *NewPtr;
    if ((Idx + 1) != VectorWidth)
      NewPtr = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, 1);

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock =
        CondBlock->splitBasicBlock(InsertPt->getIterator(), "else");
    Builder.SetInsertPoint(InsertPt);
    Instruction *OldBr = IfBlock->getTerminator();
    BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
    OldBr->eraseFromParent();
    BasicBlock *PrevIfBlock = IfBlock;
    IfBlock = NewIfBlock;

    // Add a PHI for the pointer if this isn't the last iteration.
    if ((Idx + 1) != VectorWidth) {
      PHINode *PtrPhi = Builder.CreatePHI(Ptr->getType(), 2, "ptr.phi.else");
      PtrPhi->addIncoming(NewPtr, CondBlock);
      PtrPhi->addIncoming(Ptr, PrevIfBlock);
      Ptr = PtrPhi;
    }
  }
  CI->eraseFromParent();

  ModifiedDT = true;
}
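
// Walk the function, scalarizing any unsupported masked memory intrinsics.
// Whenever a scalarization splits blocks (and therefore changes the dominator
// tree), block iteration is restarted on the updated function.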
bool ScalarizeMaskedMemIntrin::runOnFunction(Function &F) {
  bool EverMadeChange = false;

  TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);

  bool MadeChange = true;
  while (MadeChange) {
    MadeChange = false;
    for (Function::iterator I = F.begin(); I != F.end();) {
      BasicBlock *BB = &*I++;
      bool ModifiedDTOnIteration = false;
      MadeChange |= optimizeBlock(*BB, ModifiedDTOnIteration);

      // Restart BB iteration if the dominator tree of the Function was changed
      if (ModifiedDTOnIteration)
        break;
    }

    EverMadeChange |= MadeChange;
  }

  return EverMadeChange;
}
bool ScalarizeMaskedMemIntrin::optimizeBlock(BasicBlock &BB, bool &ModifiedDT) {
  bool MadeChange = false;

  BasicBlock::iterator CurInstIterator = BB.begin();
  while (CurInstIterator != BB.end()) {
    if (CallInst *CI = dyn_cast<CallInst>(&*CurInstIterator++))
      MadeChange |= optimizeCallInst(CI, ModifiedDT);
    if (ModifiedDT)
      return true;
  }

  return MadeChange;
}
bool ScalarizeMaskedMemIntrin::optimizeCallInst(CallInst *CI,
                                                bool &ModifiedDT) {
  IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI);
  if (II) {
    switch (II->getIntrinsicID()) {
    default:
      break;
    case Intrinsic::masked_load:
      // Scalarize unsupported vector masked load
      if (TTI->isLegalMaskedLoad(CI->getType()))
        return false;
      scalarizeMaskedLoad(CI, ModifiedDT);
      return true;
    case Intrinsic::masked_store:
      if (TTI->isLegalMaskedStore(CI->getArgOperand(0)->getType()))
        return false;
      scalarizeMaskedStore(CI, ModifiedDT);
      return true;
    case Intrinsic::masked_gather:
      if (TTI->isLegalMaskedGather(CI->getType()))
        return false;
      scalarizeMaskedGather(CI, ModifiedDT);
      return true;
    case Intrinsic::masked_scatter:
      if (TTI->isLegalMaskedScatter(CI->getArgOperand(0)->getType()))
        return false;
      scalarizeMaskedScatter(CI, ModifiedDT);
      return true;
    case Intrinsic::masked_expandload:
      if (TTI->isLegalMaskedExpandLoad(CI->getType()))
        return false;
      scalarizeMaskedExpandLoad(CI, ModifiedDT);
      return true;
    case Intrinsic::masked_compressstore:
      if (TTI->isLegalMaskedCompressStore(CI->getArgOperand(0)->getType()))
        return false;
      scalarizeMaskedCompressStore(CI, ModifiedDT);
      return true;
    }
  }

  return false;
}