//===--- ExpandMemCmp.cpp - Expand memcmp() to load/stores ----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass tries to expand memcmp() calls into optimally-sized loads and
// compares for the target.
//
//===----------------------------------------------------------------------===//

#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/ConstantFolding.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/CodeGen/TargetLowering.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/CodeGen/TargetSubtargetInfo.h"
#include "llvm/IR/IRBuilder.h"

using namespace llvm;

#define DEBUG_TYPE "expandmemcmp"
STATISTIC(NumMemCmpCalls, "Number of memcmp calls");
STATISTIC(NumMemCmpNotConstant, "Number of memcmp calls without constant size");
STATISTIC(NumMemCmpGreaterThanMax,
          "Number of memcmp calls with size greater than max size");
STATISTIC(NumMemCmpInlined, "Number of inlined memcmp calls");

static cl::opt<unsigned> MemCmpEqZeroNumLoadsPerBlock(
    "memcmp-num-loads-per-block", cl::Hidden, cl::init(1),
    cl::desc("The number of loads per basic block for inline expansion of "
             "memcmp that is only being compared against zero."));

static cl::opt<unsigned> MaxLoadsPerMemcmp(
    "max-loads-per-memcmp", cl::Hidden,
    cl::desc("Set maximum number of loads used in expanded memcmp"));

static cl::opt<unsigned> MaxLoadsPerMemcmpOptSize(
    "max-loads-per-memcmp-opt-size", cl::Hidden,
    cl::desc("Set maximum number of loads used in expanded memcmp for -Os/Oz"));
namespace {

// This class provides helper functions to expand a memcmp library call into an
// inline expansion.
class MemCmpExpansion {
  struct ResultBlock {
    BasicBlock *BB = nullptr;
    PHINode *PhiSrc1 = nullptr;
    PHINode *PhiSrc2 = nullptr;

    ResultBlock() = default;
  };

  CallInst *const CI;
  ResultBlock ResBlock;
  const uint64_t Size;
  unsigned MaxLoadSize;
  uint64_t NumLoadsNonOneByte;
  const uint64_t NumLoadsPerBlockForZeroCmp;
  std::vector<BasicBlock *> LoadCmpBlocks;
  BasicBlock *EndBlock;
  PHINode *PhiRes;
  const bool IsUsedForZeroCmp;
  const DataLayout &DL;
  IRBuilder<> Builder;
  // Represents the decomposition in blocks of the expansion. For example,
  // comparing 33 bytes on X86+sse can be done with 2x16-byte loads and
  // 1x1-byte load, which would be represented as [{16, 0}, {16, 16}, {1, 32}].
  struct LoadEntry {
    LoadEntry(unsigned LoadSize, uint64_t Offset)
        : LoadSize(LoadSize), Offset(Offset) {}

    // The size of the load for this block, in bytes.
    unsigned LoadSize;
    // The offset of this load from the base pointer, in bytes.
    uint64_t Offset;
  };
  using LoadEntryVector = SmallVector<LoadEntry, 8>;
  LoadEntryVector LoadSequence;

  void createLoadCmpBlocks();
  void createResultBlock();
  void setupResultBlockPHINodes();
  void setupEndBlockPHINodes();
  Value *getCompareLoadPairs(unsigned BlockIndex, unsigned &LoadIndex);
  void emitLoadCompareBlock(unsigned BlockIndex);
  void emitLoadCompareBlockMultipleLoads(unsigned BlockIndex,
                                         unsigned &LoadIndex);
  void emitLoadCompareByteBlock(unsigned BlockIndex, unsigned OffsetBytes);
  void emitMemCmpResultBlock();
  Value *getMemCmpExpansionZeroCase();
  Value *getMemCmpEqZeroOneBlock();
  Value *getMemCmpOneBlock();
  Value *getPtrToElementAtOffset(Value *Source, Type *LoadSizeType,
                                 uint64_t OffsetBytes);

  static LoadEntryVector
  computeGreedyLoadSequence(uint64_t Size, llvm::ArrayRef<unsigned> LoadSizes,
                            unsigned MaxNumLoads, unsigned &NumLoadsNonOneByte);
  static LoadEntryVector
  computeOverlappingLoadSequence(uint64_t Size, unsigned MaxLoadSize,
                                 unsigned MaxNumLoads,
                                 unsigned &NumLoadsNonOneByte);

public:
  MemCmpExpansion(CallInst *CI, uint64_t Size,
                  const TargetTransformInfo::MemCmpExpansionOptions &Options,
                  const bool IsUsedForZeroCmp, const DataLayout &TheDataLayout);

  unsigned getNumBlocks();
  uint64_t getNumLoads() const { return LoadSequence.size(); }

  Value *getMemCmpExpansion();
};
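
// Illustrative worked example (not from the original source): for Size = 15
// with target load sizes {8, 4, 2, 1}, the greedy decomposition below emits
// one load per power-of-two remainder:
//   [{8, 0}, {4, 8}, {2, 12}, {1, 14}]
// i.e. 15 = 8 + 4 + 2 + 1, with each entry's offset immediately following
// the end of the previous load.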
MemCmpExpansion::LoadEntryVector MemCmpExpansion::computeGreedyLoadSequence(
    uint64_t Size, llvm::ArrayRef<unsigned> LoadSizes,
    const unsigned MaxNumLoads, unsigned &NumLoadsNonOneByte) {
  NumLoadsNonOneByte = 0;
  LoadEntryVector LoadSequence;
  uint64_t Offset = 0;
  while (Size && !LoadSizes.empty()) {
    const unsigned LoadSize = LoadSizes.front();
    const uint64_t NumLoadsForThisSize = Size / LoadSize;
    if (LoadSequence.size() + NumLoadsForThisSize > MaxNumLoads) {
      // Do not expand if the total number of loads is larger than what the
      // target allows. Note that it's important that we exit before completing
      // the expansion to avoid using a ton of memory to store the expansion for
      // large sizes.
      return {};
    }
    if (NumLoadsForThisSize > 0) {
      for (uint64_t I = 0; I < NumLoadsForThisSize; ++I) {
        LoadSequence.push_back({LoadSize, Offset});
        Offset += LoadSize;
      }
      if (LoadSize > 1)
        ++NumLoadsNonOneByte;
      Size = Size % LoadSize;
    }
    LoadSizes = LoadSizes.drop_front();
  }
  return LoadSequence;
}
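
// Illustrative example (not from the original source): for Size = 13 and
// MaxLoadSize = 8, the greedy sequence would need three loads (8 + 4 + 1),
// while the overlapping sequence below needs only two: one load at offset 0
// and one at offset 13 - 8 = 5, which re-compares bytes [5, 8) but saves a
// load and a branch:
//   [{8, 0}, {8, 5}]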
MemCmpExpansion::LoadEntryVector
MemCmpExpansion::computeOverlappingLoadSequence(uint64_t Size,
                                                const unsigned MaxLoadSize,
                                                const unsigned MaxNumLoads,
                                                unsigned &NumLoadsNonOneByte) {
  // These are already handled by the greedy approach.
  if (Size < 2 || MaxLoadSize < 2)
    return {};

  // We try to do as many non-overlapping loads as possible starting from the
  // beginning.
  const uint64_t NumNonOverlappingLoads = Size / MaxLoadSize;
  assert(NumNonOverlappingLoads && "there must be at least one load");
  // There remain 0 to (MaxLoadSize - 1) bytes to load, this will be done with
  // an overlapping load.
  Size = Size - NumNonOverlappingLoads * MaxLoadSize;
  // Bail if we do not need an overlapping load, this is already handled by
  // the greedy approach.
  if (Size == 0)
    return {};
  // Bail if the number of loads (non-overlapping + potential overlapping one)
  // is larger than the max allowed.
  if ((NumNonOverlappingLoads + 1) > MaxNumLoads)
    return {};

  // Add non-overlapping loads.
  LoadEntryVector LoadSequence;
  uint64_t Offset = 0;
  for (uint64_t I = 0; I < NumNonOverlappingLoads; ++I) {
    LoadSequence.push_back({MaxLoadSize, Offset});
    Offset += MaxLoadSize;
  }

  // Add the last overlapping load.
  assert(Size > 0 && Size < MaxLoadSize && "broken invariant");
  LoadSequence.push_back({MaxLoadSize, Offset - (MaxLoadSize - Size)});
  NumLoadsNonOneByte = 1;
  return LoadSequence;
}
// Initialize the basic block structure required for expansion of memcmp call
// with given maximum load size and memcmp size parameter.
// This structure includes:
// 1. A list of load compare blocks - LoadCmpBlocks.
// 2. An EndBlock, split from original instruction point, which is the block to
// return from.
// 3. ResultBlock, block to branch to for early exit when a
// LoadCmpBlock finds a difference.
MemCmpExpansion::MemCmpExpansion(
    CallInst *const CI, uint64_t Size,
    const TargetTransformInfo::MemCmpExpansionOptions &Options,
    const bool IsUsedForZeroCmp, const DataLayout &TheDataLayout)
    : CI(CI), Size(Size), MaxLoadSize(0), NumLoadsNonOneByte(0),
      NumLoadsPerBlockForZeroCmp(Options.NumLoadsPerBlock),
      IsUsedForZeroCmp(IsUsedForZeroCmp), DL(TheDataLayout), Builder(CI) {
  assert(Size > 0 && "zero blocks");
  // Scale the max size down if the target can load more bytes than we need.
  llvm::ArrayRef<unsigned> LoadSizes(Options.LoadSizes);
  while (!LoadSizes.empty() && LoadSizes.front() > Size) {
    LoadSizes = LoadSizes.drop_front();
  }
  assert(!LoadSizes.empty() && "cannot load Size bytes");
  MaxLoadSize = LoadSizes.front();
  // Compute the decomposition.
  unsigned GreedyNumLoadsNonOneByte = 0;
  LoadSequence = computeGreedyLoadSequence(Size, LoadSizes, Options.MaxNumLoads,
                                           GreedyNumLoadsNonOneByte);
  NumLoadsNonOneByte = GreedyNumLoadsNonOneByte;
  assert(LoadSequence.size() <= Options.MaxNumLoads && "broken invariant");
  // If we allow overlapping loads and the load sequence is not already optimal,
  // use overlapping loads.
  if (Options.AllowOverlappingLoads &&
      (LoadSequence.empty() || LoadSequence.size() > 2)) {
    unsigned OverlappingNumLoadsNonOneByte = 0;
    auto OverlappingLoads = computeOverlappingLoadSequence(
        Size, MaxLoadSize, Options.MaxNumLoads, OverlappingNumLoadsNonOneByte);
    if (!OverlappingLoads.empty() &&
        (LoadSequence.empty() ||
         OverlappingLoads.size() < LoadSequence.size())) {
      LoadSequence = OverlappingLoads;
      NumLoadsNonOneByte = OverlappingNumLoadsNonOneByte;
    }
  }
  assert(LoadSequence.size() <= Options.MaxNumLoads && "broken invariant");
}
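
// Illustrative note (not from the original source): getNumBlocks() below is
// a ceiling division in the zero-equality case. With 5 loads and
// NumLoadsPerBlockForZeroCmp = 2, the expansion uses 5/2 + 1 = 3 blocks; in
// the general case, every load gets its own block.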
unsigned MemCmpExpansion::getNumBlocks() {
  if (IsUsedForZeroCmp)
    return getNumLoads() / NumLoadsPerBlockForZeroCmp +
           (getNumLoads() % NumLoadsPerBlockForZeroCmp != 0 ? 1 : 0);
  return getNumLoads();
}

void MemCmpExpansion::createLoadCmpBlocks() {
  for (unsigned i = 0; i < getNumBlocks(); i++) {
    BasicBlock *BB = BasicBlock::Create(CI->getContext(), "loadbb",
                                        EndBlock->getParent(), EndBlock);
    LoadCmpBlocks.push_back(BB);
  }
}

void MemCmpExpansion::createResultBlock() {
  ResBlock.BB = BasicBlock::Create(CI->getContext(), "res_block",
                                   EndBlock->getParent(), EndBlock);
}

/// Return a pointer to an element of type `LoadSizeType` at offset
/// `OffsetBytes`.
Value *MemCmpExpansion::getPtrToElementAtOffset(Value *Source,
                                                Type *LoadSizeType,
                                                uint64_t OffsetBytes) {
  if (OffsetBytes > 0) {
    auto *ByteType = Type::getInt8Ty(CI->getContext());
    Source = Builder.CreateGEP(
        ByteType, Builder.CreateBitCast(Source, ByteType->getPointerTo()),
        ConstantInt::get(ByteType, OffsetBytes));
  }
  return Builder.CreateBitCast(Source, LoadSizeType->getPointerTo());
}
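
// For example (illustrative, not from the original source): for
// LoadSizeType = i32 and OffsetBytes = 4, this emits
//   %p = getelementptr i8, i8* %src, i8 4
// followed by a bitcast of %p to i32*, i.e. a byte-granular address
// computation independent of the load type.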
// This function creates the IR instructions for loading and comparing 1 byte.
// It loads 1 byte from each source of the memcmp parameters at the given
// offset. It then subtracts the two loaded values and adds this result to the
// final phi node for selecting the memcmp result.
void MemCmpExpansion::emitLoadCompareByteBlock(unsigned BlockIndex,
                                               unsigned OffsetBytes) {
  Builder.SetInsertPoint(LoadCmpBlocks[BlockIndex]);
  Type *LoadSizeType = Type::getInt8Ty(CI->getContext());
  Value *Source1 =
      getPtrToElementAtOffset(CI->getArgOperand(0), LoadSizeType, OffsetBytes);
  Value *Source2 =
      getPtrToElementAtOffset(CI->getArgOperand(1), LoadSizeType, OffsetBytes);
  Value *LoadSrc1 = Builder.CreateLoad(LoadSizeType, Source1);
  Value *LoadSrc2 = Builder.CreateLoad(LoadSizeType, Source2);
  LoadSrc1 = Builder.CreateZExt(LoadSrc1, Type::getInt32Ty(CI->getContext()));
  LoadSrc2 = Builder.CreateZExt(LoadSrc2, Type::getInt32Ty(CI->getContext()));
  Value *Diff = Builder.CreateSub(LoadSrc1, LoadSrc2);

  PhiRes->addIncoming(Diff, LoadCmpBlocks[BlockIndex]);

  if (BlockIndex < (LoadCmpBlocks.size() - 1)) {
    // Early exit branch if difference found to EndBlock. Otherwise, continue
    // to next LoadCmpBlock.
    Value *Cmp = Builder.CreateICmp(ICmpInst::ICMP_NE, Diff,
                                    ConstantInt::get(Diff->getType(), 0));
    BranchInst *CmpBr =
        BranchInst::Create(EndBlock, LoadCmpBlocks[BlockIndex + 1], Cmp);
    Builder.Insert(CmpBr);
  } else {
    // The last block has an unconditional branch to EndBlock.
    BranchInst *CmpBr = BranchInst::Create(EndBlock);
    Builder.Insert(CmpBr);
  }
}
/// Generate an equality comparison for one or more pairs of loaded values.
/// This is used in the case where the memcmp() call is compared equal or not
/// equal to zero.
Value *MemCmpExpansion::getCompareLoadPairs(unsigned BlockIndex,
                                            unsigned &LoadIndex) {
  assert(LoadIndex < getNumLoads() &&
         "getCompareLoadPairs() called with no remaining loads");
  std::vector<Value *> XorList, OrList;
  Value *Diff = nullptr;

  const unsigned NumLoads =
      std::min(getNumLoads() - LoadIndex, NumLoadsPerBlockForZeroCmp);

  // For a single-block expansion, start inserting before the memcmp call.
  if (LoadCmpBlocks.empty())
    Builder.SetInsertPoint(CI);
  else
    Builder.SetInsertPoint(LoadCmpBlocks[BlockIndex]);

  Value *Cmp = nullptr;
  // If we have multiple loads per block, we need to generate a composite
  // comparison using xor+or. The type for the combinations is the largest load
  // type.
  IntegerType *const MaxLoadType =
      NumLoads == 1 ? nullptr
                    : IntegerType::get(CI->getContext(), MaxLoadSize * 8);
  for (unsigned i = 0; i < NumLoads; ++i, ++LoadIndex) {
    const LoadEntry &CurLoadEntry = LoadSequence[LoadIndex];

    IntegerType *LoadSizeType =
        IntegerType::get(CI->getContext(), CurLoadEntry.LoadSize * 8);

    Value *Source1 = getPtrToElementAtOffset(CI->getArgOperand(0), LoadSizeType,
                                             CurLoadEntry.Offset);
    Value *Source2 = getPtrToElementAtOffset(CI->getArgOperand(1), LoadSizeType,
                                             CurLoadEntry.Offset);

    // Get a constant or load a value for each source address.
    Value *LoadSrc1 = nullptr;
    if (auto *Source1C = dyn_cast<Constant>(Source1))
      LoadSrc1 = ConstantFoldLoadFromConstPtr(Source1C, LoadSizeType, DL);
    if (!LoadSrc1)
      LoadSrc1 = Builder.CreateLoad(LoadSizeType, Source1);

    Value *LoadSrc2 = nullptr;
    if (auto *Source2C = dyn_cast<Constant>(Source2))
      LoadSrc2 = ConstantFoldLoadFromConstPtr(Source2C, LoadSizeType, DL);
    if (!LoadSrc2)
      LoadSrc2 = Builder.CreateLoad(LoadSizeType, Source2);

    if (NumLoads != 1) {
      if (LoadSizeType != MaxLoadType) {
        LoadSrc1 = Builder.CreateZExt(LoadSrc1, MaxLoadType);
        LoadSrc2 = Builder.CreateZExt(LoadSrc2, MaxLoadType);
      }
      // If we have multiple loads per block, we need to generate a composite
      // comparison using xor+or.
      Diff = Builder.CreateXor(LoadSrc1, LoadSrc2);
      Diff = Builder.CreateZExt(Diff, MaxLoadType);
      XorList.push_back(Diff);
    } else {
      // If there's only one load per block, we just compare the loaded values.
      Cmp = Builder.CreateICmpNE(LoadSrc1, LoadSrc2);
    }
  }
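
  // Illustrative reduction (not from the original source): with four xor
  // results x0..x3, pairWiseOr below builds (x0|x1) and (x2|x3) on the first
  // pass, then ((x0|x1)|(x2|x3)) on the second, so the final icmp-ne-zero
  // tests whether any byte differed using a log-depth tree of ORs.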
  auto pairWiseOr = [&](std::vector<Value *> &InList) -> std::vector<Value *> {
    std::vector<Value *> OutList;
    for (unsigned i = 0; i < InList.size() - 1; i = i + 2) {
      Value *Or = Builder.CreateOr(InList[i], InList[i + 1]);
      OutList.push_back(Or);
    }
    if (InList.size() % 2 != 0)
      OutList.push_back(InList.back());
    return OutList;
  };

  if (!Cmp) {
    // Pairwise OR the XOR results.
    OrList = pairWiseOr(XorList);

    // Pairwise OR the OR results until one result left.
    while (OrList.size() != 1) {
      OrList = pairWiseOr(OrList);
    }

    assert(Diff && "Failed to find comparison diff");
    Cmp = Builder.CreateICmpNE(OrList[0], ConstantInt::get(Diff->getType(), 0));
  }

  return Cmp;
}
void MemCmpExpansion::emitLoadCompareBlockMultipleLoads(unsigned BlockIndex,
                                                        unsigned &LoadIndex) {
  Value *Cmp = getCompareLoadPairs(BlockIndex, LoadIndex);

  BasicBlock *NextBB = (BlockIndex == (LoadCmpBlocks.size() - 1))
                           ? EndBlock
                           : LoadCmpBlocks[BlockIndex + 1];
  // Early exit branch if difference found to ResultBlock. Otherwise,
  // continue to next LoadCmpBlock or EndBlock.
  BranchInst *CmpBr = BranchInst::Create(ResBlock.BB, NextBB, Cmp);
  Builder.Insert(CmpBr);

  // Add a phi edge for the last LoadCmpBlock to Endblock with a value of 0
  // since early exit to ResultBlock was not taken (no difference was found in
  // any of the bytes).
  if (BlockIndex == LoadCmpBlocks.size() - 1) {
    Value *Zero = ConstantInt::get(Type::getInt32Ty(CI->getContext()), 0);
    PhiRes->addIncoming(Zero, LoadCmpBlocks[BlockIndex]);
  }
}
// This function creates the IR instructions for loading and comparing using
// the given LoadSize. It loads the number of bytes specified by LoadSize from
// each source of the memcmp parameters. It then does a subtract to see if
// there was a difference in the loaded values. If a difference is found, it
// branches with an early exit to the ResultBlock for calculating which source
// was larger. Otherwise, it falls through to either the next LoadCmpBlock or
// the EndBlock if this is the last LoadCmpBlock. Loading 1 byte is handled
// with a special case through emitLoadCompareByteBlock. The special handling
// can simply subtract the loaded values and add it to the result phi node.
void MemCmpExpansion::emitLoadCompareBlock(unsigned BlockIndex) {
  // There is one load per block in this case, BlockIndex == LoadIndex.
  const LoadEntry &CurLoadEntry = LoadSequence[BlockIndex];

  if (CurLoadEntry.LoadSize == 1) {
    MemCmpExpansion::emitLoadCompareByteBlock(BlockIndex, CurLoadEntry.Offset);
    return;
  }

  Type *LoadSizeType =
      IntegerType::get(CI->getContext(), CurLoadEntry.LoadSize * 8);
  Type *MaxLoadType = IntegerType::get(CI->getContext(), MaxLoadSize * 8);
  assert(CurLoadEntry.LoadSize <= MaxLoadSize && "Unexpected load type");

  Builder.SetInsertPoint(LoadCmpBlocks[BlockIndex]);

  Value *Source1 = getPtrToElementAtOffset(CI->getArgOperand(0), LoadSizeType,
                                           CurLoadEntry.Offset);
  Value *Source2 = getPtrToElementAtOffset(CI->getArgOperand(1), LoadSizeType,
                                           CurLoadEntry.Offset);

  // Load LoadSizeType from the base address.
  Value *LoadSrc1 = Builder.CreateLoad(LoadSizeType, Source1);
  Value *LoadSrc2 = Builder.CreateLoad(LoadSizeType, Source2);

  if (DL.isLittleEndian()) {
    Function *Bswap = Intrinsic::getDeclaration(CI->getModule(),
                                                Intrinsic::bswap, LoadSizeType);
    LoadSrc1 = Builder.CreateCall(Bswap, LoadSrc1);
    LoadSrc2 = Builder.CreateCall(Bswap, LoadSrc2);
  }

  if (LoadSizeType != MaxLoadType) {
    LoadSrc1 = Builder.CreateZExt(LoadSrc1, MaxLoadType);
    LoadSrc2 = Builder.CreateZExt(LoadSrc2, MaxLoadType);
  }

  // Add the loaded values to the phi nodes for calculating memcmp result only
  // if result is not used in a zero equality.
  if (!IsUsedForZeroCmp) {
    ResBlock.PhiSrc1->addIncoming(LoadSrc1, LoadCmpBlocks[BlockIndex]);
    ResBlock.PhiSrc2->addIncoming(LoadSrc2, LoadCmpBlocks[BlockIndex]);
  }

  Value *Cmp = Builder.CreateICmp(ICmpInst::ICMP_EQ, LoadSrc1, LoadSrc2);
  BasicBlock *NextBB = (BlockIndex == (LoadCmpBlocks.size() - 1))
                           ? EndBlock
                           : LoadCmpBlocks[BlockIndex + 1];
  // Early exit branch if difference found to ResultBlock. Otherwise, continue
  // to next LoadCmpBlock or EndBlock.
  BranchInst *CmpBr = BranchInst::Create(NextBB, ResBlock.BB, Cmp);
  Builder.Insert(CmpBr);

  // Add a phi edge for the last LoadCmpBlock to Endblock with a value of 0
  // since early exit to ResultBlock was not taken (no difference was found in
  // any of the bytes).
  if (BlockIndex == LoadCmpBlocks.size() - 1) {
    Value *Zero = ConstantInt::get(Type::getInt32Ty(CI->getContext()), 0);
    PhiRes->addIncoming(Zero, LoadCmpBlocks[BlockIndex]);
  }
}
// This function populates the ResultBlock with a sequence to calculate the
// memcmp result. It compares the two loaded source values and returns -1 if
// src1 < src2 and 1 if src1 > src2.
void MemCmpExpansion::emitMemCmpResultBlock() {
  // Special case: if memcmp result is used in a zero equality, result does not
  // need to be calculated and can simply return 1.
  if (IsUsedForZeroCmp) {
    BasicBlock::iterator InsertPt = ResBlock.BB->getFirstInsertionPt();
    Builder.SetInsertPoint(ResBlock.BB, InsertPt);
    Value *Res = ConstantInt::get(Type::getInt32Ty(CI->getContext()), 1);
    PhiRes->addIncoming(Res, ResBlock.BB);
    BranchInst *NewBr = BranchInst::Create(EndBlock);
    Builder.Insert(NewBr);
    return;
  }
  BasicBlock::iterator InsertPt = ResBlock.BB->getFirstInsertionPt();
  Builder.SetInsertPoint(ResBlock.BB, InsertPt);

  Value *Cmp = Builder.CreateICmp(ICmpInst::ICMP_ULT, ResBlock.PhiSrc1,
                                  ResBlock.PhiSrc2);

  Value *Res =
      Builder.CreateSelect(Cmp, ConstantInt::get(Builder.getInt32Ty(), -1),
                           ConstantInt::get(Builder.getInt32Ty(), 1));

  BranchInst *NewBr = BranchInst::Create(EndBlock);
  Builder.Insert(NewBr);
  PhiRes->addIncoming(Res, ResBlock.BB);
}

void MemCmpExpansion::setupResultBlockPHINodes() {
  Type *MaxLoadType = IntegerType::get(CI->getContext(), MaxLoadSize * 8);
  Builder.SetInsertPoint(ResBlock.BB);
  // Note: this assumes one load per block.
  ResBlock.PhiSrc1 =
      Builder.CreatePHI(MaxLoadType, NumLoadsNonOneByte, "phi.src1");
  ResBlock.PhiSrc2 =
      Builder.CreatePHI(MaxLoadType, NumLoadsNonOneByte, "phi.src2");
}

void MemCmpExpansion::setupEndBlockPHINodes() {
  Builder.SetInsertPoint(&EndBlock->front());
  PhiRes = Builder.CreatePHI(Type::getInt32Ty(CI->getContext()), 2, "phi.res");
}
Value *MemCmpExpansion::getMemCmpExpansionZeroCase() {
  unsigned LoadIndex = 0;
  // This loop populates each of the LoadCmpBlocks with the IR sequence to
  // handle multiple loads per block.
  for (unsigned I = 0; I < getNumBlocks(); ++I) {
    emitLoadCompareBlockMultipleLoads(I, LoadIndex);
  }

  emitMemCmpResultBlock();
  return PhiRes;
}

/// A memcmp expansion that compares equality with 0 and only has one block of
/// load and compare can bypass the compare, branch, and phi IR that is required
/// in the general case.
Value *MemCmpExpansion::getMemCmpEqZeroOneBlock() {
  unsigned LoadIndex = 0;
  Value *Cmp = getCompareLoadPairs(0, LoadIndex);
  assert(LoadIndex == getNumLoads() && "some entries were not consumed");
  return Builder.CreateZExt(Cmp, Type::getInt32Ty(CI->getContext()));
}

/// A memcmp expansion that only has one block of load and compare can bypass
/// the compare, branch, and phi IR that is required in the general case.
Value *MemCmpExpansion::getMemCmpOneBlock() {
  Type *LoadSizeType = IntegerType::get(CI->getContext(), Size * 8);
  Value *Source1 = CI->getArgOperand(0);
  Value *Source2 = CI->getArgOperand(1);

  // Cast source to LoadSizeType*.
  if (Source1->getType() != LoadSizeType)
    Source1 = Builder.CreateBitCast(Source1, LoadSizeType->getPointerTo());
  if (Source2->getType() != LoadSizeType)
    Source2 = Builder.CreateBitCast(Source2, LoadSizeType->getPointerTo());

  // Load LoadSizeType from the base address.
  Value *LoadSrc1 = Builder.CreateLoad(LoadSizeType, Source1);
  Value *LoadSrc2 = Builder.CreateLoad(LoadSizeType, Source2);

  if (DL.isLittleEndian() && Size != 1) {
    Function *Bswap = Intrinsic::getDeclaration(CI->getModule(),
                                                Intrinsic::bswap, LoadSizeType);
    LoadSrc1 = Builder.CreateCall(Bswap, LoadSrc1);
    LoadSrc2 = Builder.CreateCall(Bswap, LoadSrc2);
  }

  if (Size < 4) {
    // The i8 and i16 cases don't need compares. We zext the loaded values and
    // subtract them to get the suitable negative, zero, or positive i32 result.
    LoadSrc1 = Builder.CreateZExt(LoadSrc1, Builder.getInt32Ty());
    LoadSrc2 = Builder.CreateZExt(LoadSrc2, Builder.getInt32Ty());
    return Builder.CreateSub(LoadSrc1, LoadSrc2);
  }

  // The result of memcmp is negative, zero, or positive, so produce that by
  // subtracting 2 extended compare bits: sub (ugt, ult).
  // If a target prefers to use selects to get -1/0/1, they should be able
  // to transform this later. The inverse transform (going from selects to math)
  // may not be possible in the DAG because the selects got converted into
  // branches before we got there.
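  // For example (illustrative, not from the original source): if
  // LoadSrc1 > LoadSrc2 then (ugt, ult) is (1, 0) and the sub yields 1; if
  // equal, (0, 0) yields 0; if less, (0, 1) yields -1. All three memcmp
  // result signs fall out of two compares and one subtract, with no branches.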
  Value *CmpUGT = Builder.CreateICmpUGT(LoadSrc1, LoadSrc2);
  Value *CmpULT = Builder.CreateICmpULT(LoadSrc1, LoadSrc2);
  Value *ZextUGT = Builder.CreateZExt(CmpUGT, Builder.getInt32Ty());
  Value *ZextULT = Builder.CreateZExt(CmpULT, Builder.getInt32Ty());
  return Builder.CreateSub(ZextUGT, ZextULT);
}
// This function expands the memcmp call into an inline expansion and returns
// the memcmp result.
Value *MemCmpExpansion::getMemCmpExpansion() {
  // Create the basic block framework for a multi-block expansion.
  if (getNumBlocks() != 1) {
    BasicBlock *StartBlock = CI->getParent();
    EndBlock = StartBlock->splitBasicBlock(CI, "endblock");
    setupEndBlockPHINodes();
    createResultBlock();

    // If return value of memcmp is not used in a zero equality, we need to
    // calculate which source was larger. The calculation requires the
    // two loaded source values of each load compare block.
    // These will be saved in the phi nodes created by setupResultBlockPHINodes.
    if (!IsUsedForZeroCmp) setupResultBlockPHINodes();

    // Create the number of required load compare basic blocks.
    createLoadCmpBlocks();

    // Update the terminator added by splitBasicBlock to branch to the first
    // LoadCmpBlock.
    StartBlock->getTerminator()->setSuccessor(0, LoadCmpBlocks[0]);
  }

  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  if (IsUsedForZeroCmp)
    return getNumBlocks() == 1 ? getMemCmpEqZeroOneBlock()
                               : getMemCmpExpansionZeroCase();

  if (getNumBlocks() == 1)
    return getMemCmpOneBlock();

  for (unsigned I = 0; I < getNumBlocks(); ++I) {
    emitLoadCompareBlock(I);
  }

  emitMemCmpResultBlock();
  return PhiRes;
}
// This function checks to see if an expansion of memcmp can be generated.
// It checks for constant compare size that is less than the max inline size.
// If an expansion cannot occur, returns false to leave as a library call.
// Otherwise, the library call is replaced with a new IR instruction sequence.
/// We want to transform:
/// %call = call signext i32 @memcmp(i8* %0, i8* %1, i64 15)
/// To:
/// loadbb:
///  %0 = bitcast i32* %buffer2 to i8*
///  %1 = bitcast i32* %buffer1 to i8*
///  %2 = bitcast i8* %1 to i64*
///  %3 = bitcast i8* %0 to i64*
///  %4 = load i64, i64* %2
///  %5 = load i64, i64* %3
///  %6 = call i64 @llvm.bswap.i64(i64 %4)
///  %7 = call i64 @llvm.bswap.i64(i64 %5)
///  %8 = sub i64 %6, %7
///  %9 = icmp ne i64 %8, 0
///  br i1 %9, label %res_block, label %loadbb1
/// res_block:                             ; preds = %loadbb2, %loadbb1, %loadbb
///  %phi.src1 = phi i64 [ %6, %loadbb ], [ %22, %loadbb1 ], [ %36, %loadbb2 ]
///  %phi.src2 = phi i64 [ %7, %loadbb ], [ %23, %loadbb1 ], [ %37, %loadbb2 ]
///  %10 = icmp ult i64 %phi.src1, %phi.src2
///  %11 = select i1 %10, i32 -1, i32 1
///  br label %endblock
/// loadbb1:                               ; preds = %loadbb
///  %12 = bitcast i32* %buffer2 to i8*
///  %13 = bitcast i32* %buffer1 to i8*
///  %14 = bitcast i8* %13 to i32*
///  %15 = bitcast i8* %12 to i32*
///  %16 = getelementptr i32, i32* %14, i32 2
///  %17 = getelementptr i32, i32* %15, i32 2
///  %18 = load i32, i32* %16
///  %19 = load i32, i32* %17
///  %20 = call i32 @llvm.bswap.i32(i32 %18)
///  %21 = call i32 @llvm.bswap.i32(i32 %19)
///  %22 = zext i32 %20 to i64
///  %23 = zext i32 %21 to i64
///  %24 = sub i64 %22, %23
///  %25 = icmp ne i64 %24, 0
///  br i1 %25, label %res_block, label %loadbb2
/// loadbb2:                               ; preds = %loadbb1
///  %26 = bitcast i32* %buffer2 to i8*
///  %27 = bitcast i32* %buffer1 to i8*
///  %28 = bitcast i8* %27 to i16*
///  %29 = bitcast i8* %26 to i16*
///  %30 = getelementptr i16, i16* %28, i16 6
///  %31 = getelementptr i16, i16* %29, i16 6
///  %32 = load i16, i16* %30
///  %33 = load i16, i16* %31
///  %34 = call i16 @llvm.bswap.i16(i16 %32)
///  %35 = call i16 @llvm.bswap.i16(i16 %33)
///  %36 = zext i16 %34 to i64
///  %37 = zext i16 %35 to i64
///  %38 = sub i64 %36, %37
///  %39 = icmp ne i64 %38, 0
///  br i1 %39, label %res_block, label %loadbb3
/// loadbb3:                               ; preds = %loadbb2
///  %40 = bitcast i32* %buffer2 to i8*
///  %41 = bitcast i32* %buffer1 to i8*
///  %42 = getelementptr i8, i8* %41, i8 14
///  %43 = getelementptr i8, i8* %40, i8 14
///  %44 = load i8, i8* %42
///  %45 = load i8, i8* %43
///  %46 = zext i8 %44 to i32
///  %47 = zext i8 %45 to i32
///  %48 = sub i32 %46, %47
///  br label %endblock
/// endblock:                              ; preds = %res_block, %loadbb3
///  %phi.res = phi i32 [ %48, %loadbb3 ], [ %11, %res_block ]
///  ret i32 %phi.res
static bool expandMemCmp(CallInst *CI, const TargetTransformInfo *TTI,
                         const TargetLowering *TLI, const DataLayout *DL) {
  NumMemCmpCalls++;

  // Early exit from expansion if -Oz.
  if (CI->getFunction()->hasMinSize())
    return false;

  // Early exit from expansion if size is not a constant.
  ConstantInt *SizeCast = dyn_cast<ConstantInt>(CI->getArgOperand(2));
  if (!SizeCast) {
    NumMemCmpNotConstant++;
    return false;
  }
  const uint64_t SizeVal = SizeCast->getZExtValue();

  if (SizeVal == 0) {
    return false;
  }
  // TTI call to check if target would like to expand memcmp. Also, get the
  // available load sizes.
  const bool IsUsedForZeroCmp = isOnlyUsedInZeroEqualityComparison(CI);
  auto Options = TTI->enableMemCmpExpansion(CI->getFunction()->hasOptSize(),
                                            IsUsedForZeroCmp);
  if (!Options) return false;

  if (MemCmpEqZeroNumLoadsPerBlock.getNumOccurrences())
    Options.NumLoadsPerBlock = MemCmpEqZeroNumLoadsPerBlock;

  if (CI->getFunction()->hasOptSize() &&
      MaxLoadsPerMemcmpOptSize.getNumOccurrences())
    Options.MaxNumLoads = MaxLoadsPerMemcmpOptSize;

  if (!CI->getFunction()->hasOptSize() && MaxLoadsPerMemcmp.getNumOccurrences())
    Options.MaxNumLoads = MaxLoadsPerMemcmp;

  MemCmpExpansion Expansion(CI, SizeVal, Options, IsUsedForZeroCmp, *DL);

  // Don't expand if this will require more loads than desired by the target.
  if (Expansion.getNumLoads() == 0) {
    NumMemCmpGreaterThanMax++;
    return false;
  }

  NumMemCmpInlined++;

  Value *Res = Expansion.getMemCmpExpansion();

  // Replace call with result of expansion and erase call.
  CI->replaceAllUsesWith(Res);
  CI->eraseFromParent();

  return true;
}
class ExpandMemCmpPass : public FunctionPass {
public:
  static char ID;

  ExpandMemCmpPass() : FunctionPass(ID) {
    initializeExpandMemCmpPassPass(*PassRegistry::getPassRegistry());
  }

  bool runOnFunction(Function &F) override {
    if (skipFunction(F)) return false;

    auto *TPC = getAnalysisIfAvailable<TargetPassConfig>();
    if (!TPC) {
      return false;
    }
    const TargetLowering* TL =
        TPC->getTM<TargetMachine>().getSubtargetImpl(F)->getTargetLowering();

    const TargetLibraryInfo *TLI =
        &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
    const TargetTransformInfo *TTI =
        &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
    auto PA = runImpl(F, TLI, TTI, TL);
    return !PA.areAllPreserved();
  }

private:
  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<TargetLibraryInfoWrapperPass>();
    AU.addRequired<TargetTransformInfoWrapperPass>();
    FunctionPass::getAnalysisUsage(AU);
  }

  PreservedAnalyses runImpl(Function &F, const TargetLibraryInfo *TLI,
                            const TargetTransformInfo *TTI,
                            const TargetLowering* TL);
  // Returns true if a change was made.
  bool runOnBlock(BasicBlock &BB, const TargetLibraryInfo *TLI,
                  const TargetTransformInfo *TTI, const TargetLowering* TL,
                  const DataLayout& DL);
};
bool ExpandMemCmpPass::runOnBlock(
    BasicBlock &BB, const TargetLibraryInfo *TLI,
    const TargetTransformInfo *TTI, const TargetLowering* TL,
    const DataLayout& DL) {
  for (Instruction& I : BB) {
    CallInst *CI = dyn_cast<CallInst>(&I);
    if (!CI) {
      continue;
    }
    LibFunc Func;
    if (TLI->getLibFunc(ImmutableCallSite(CI), Func) &&
        (Func == LibFunc_memcmp || Func == LibFunc_bcmp) &&
        expandMemCmp(CI, TTI, TL, &DL)) {
      return true;
    }
  }
  return false;
}
PreservedAnalyses ExpandMemCmpPass::runImpl(
    Function &F, const TargetLibraryInfo *TLI, const TargetTransformInfo *TTI,
    const TargetLowering* TL) {
  const DataLayout& DL = F.getParent()->getDataLayout();
  bool MadeChanges = false;
  for (auto BBIt = F.begin(); BBIt != F.end();) {
    if (runOnBlock(*BBIt, TLI, TTI, TL, DL)) {
      MadeChanges = true;
      // If changes were made, restart the function from the beginning, since
      // the structure of the function was changed.
      BBIt = F.begin();
    } else {
      ++BBIt;
    }
  }
  return MadeChanges ? PreservedAnalyses::none() : PreservedAnalyses::all();
}

} // namespace
char ExpandMemCmpPass::ID = 0;
INITIALIZE_PASS_BEGIN(ExpandMemCmpPass, "expandmemcmp",
                      "Expand memcmp() to load/stores", false, false)
INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
INITIALIZE_PASS_END(ExpandMemCmpPass, "expandmemcmp",
                    "Expand memcmp() to load/stores", false, false)

FunctionPass *llvm::createExpandMemCmpPass() {
  return new ExpandMemCmpPass();
}
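
// Usage note (a sketch, not part of this file): the pass is registered under
// the name "expandmemcmp" and is normally scheduled by the codegen pipeline,
// but it can typically be exercised in isolation along the lines of:
//   opt -S -expandmemcmp -mtriple=x86_64-unknown-unknown input.ll
// Because it queries TargetPassConfig and TargetTransformInfo, a target
// triple (and thus a registered backend) is assumed to be available; without
// TargetPassConfig the pass simply bails out in runOnFunction.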