LowerSwitch.cpp 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523
  1. //===- LowerSwitch.cpp - Eliminate Switch instructions --------------------===//
  2. //
  3. // The LLVM Compiler Infrastructure
  4. //
  5. // This file is distributed under the University of Illinois Open Source
  6. // License. See LICENSE.TXT for details.
  7. //
  8. //===----------------------------------------------------------------------===//
  9. //
  10. // The LowerSwitch transformation rewrites switch instructions with a sequence
  11. // of branches, which allows targets to get away with not implementing the
  12. // switch instruction until it is convenient.
  13. //
  14. //===----------------------------------------------------------------------===//
  15. #include "llvm/Transforms/Scalar.h"
  16. #include "llvm/ADT/STLExtras.h"
  17. #include "llvm/IR/CFG.h"
  18. #include "llvm/IR/Constants.h"
  19. #include "llvm/IR/Function.h"
  20. #include "llvm/IR/Instructions.h"
  21. #include "llvm/IR/LLVMContext.h"
  22. #include "llvm/Pass.h"
  23. #include "llvm/Support/Compiler.h"
  24. #include "llvm/Support/Debug.h"
  25. #include "llvm/Support/raw_ostream.h"
  26. #include "llvm/Transforms/Utils/BasicBlockUtils.h"
  27. #include "llvm/Transforms/Utils/UnifyFunctionExitNodes.h"
  28. #include <algorithm>
  29. using namespace llvm;
  30. #define DEBUG_TYPE "lower-switch"
  31. namespace {
  /// A closed interval [Low, High] of signed 64-bit values.
  struct IntRange {
    int64_t Low;
    int64_t High;
  };

  /// Return true iff R is covered by a single element of Ranges.
  ///
  /// Ranges must be sorted, non-overlapping and non-adjacent; under that
  /// precondition the only candidate is the first range that does not end
  /// before R ends, so one binary search on the High fields is enough.
  static bool IsInRanges(const IntRange &R,
                         const std::vector<IntRange> &Ranges) {
    // Orders a stored range before the query when it ends strictly earlier.
    const auto EndsBefore = [](const IntRange &Candidate,
                               const IntRange &Query) {
      return Candidate.High < Query.High;
    };

    const auto Hit =
        std::lower_bound(Ranges.begin(), Ranges.end(), R, EndsBefore);

    // Every range ends before R does: nothing can cover it.
    if (Hit == Ranges.end())
      return false;

    // Hit ends at or after R.High; it covers R iff it also starts early enough.
    return Hit->Low <= R.Low;
  }
  47. /// Replace all SwitchInst instructions with chained branch instructions.
  48. class LowerSwitch : public FunctionPass {
  49. public:
  50. static char ID; // Pass identification, replacement for typeid
  51. LowerSwitch() : FunctionPass(ID) {
  52. initializeLowerSwitchPass(*PassRegistry::getPassRegistry());
  53. }
  54. bool runOnFunction(Function &F) override;
  55. struct CaseRange {
  56. ConstantInt* Low;
  57. ConstantInt* High;
  58. BasicBlock* BB;
  59. CaseRange(ConstantInt *low, ConstantInt *high, BasicBlock *bb)
  60. : Low(low), High(high), BB(bb) {}
  61. };
  62. typedef std::vector<CaseRange> CaseVector;
  63. typedef std::vector<CaseRange>::iterator CaseItr;
  64. private:
  65. void processSwitchInst(SwitchInst *SI, SmallPtrSetImpl<BasicBlock*> &DeleteList);
  66. BasicBlock *switchConvert(CaseItr Begin, CaseItr End,
  67. ConstantInt *LowerBound, ConstantInt *UpperBound,
  68. Value *Val, BasicBlock *Predecessor,
  69. BasicBlock *OrigBlock, BasicBlock *Default,
  70. const std::vector<IntRange> &UnreachableRanges);
  71. BasicBlock *newLeafBlock(CaseRange &Leaf, Value *Val, BasicBlock *OrigBlock,
  72. BasicBlock *Default);
  73. unsigned Clusterify(CaseVector &Cases, SwitchInst *SI);
  74. };
  75. /// The comparison function for sorting the switch case values in the vector.
  76. /// WARNING: Case ranges should be disjoint!
  77. struct CaseCmp {
  78. bool operator () (const LowerSwitch::CaseRange& C1,
  79. const LowerSwitch::CaseRange& C2) {
  80. const ConstantInt* CI1 = cast<const ConstantInt>(C1.Low);
  81. const ConstantInt* CI2 = cast<const ConstantInt>(C2.High);
  82. return CI1->getValue().slt(CI2->getValue());
  83. }
  84. };
  85. }
  // Storage for the pass identifier declared in the class. The class comment
  // describes it as a replacement for typeid, and the constructor hands it to
  // FunctionPass.
  char LowerSwitch::ID = 0;
  // Registers the pass under the argument string "lowerswitch" with a short
  // description.
  // NOTE(review): the two trailing `false` flags are presumably the
  // "CFG only" and "is analysis" switches -- confirm against PassSupport.h.
  INITIALIZE_PASS(LowerSwitch, "lowerswitch",
                  "Lower SwitchInst's to branches", false, false)
  // Publicly exposed interface to the pass: a reference to its ID that code
  // outside this file can name.
  char &llvm::LowerSwitchID = LowerSwitch::ID;
  // createLowerSwitchPass - Interface to this file: returns a newly allocated
  // LowerSwitch pass instance.
  // NOTE(review): the caller presumably takes ownership of the returned
  // object (e.g. a pass manager) -- nothing in this file deletes it.
  FunctionPass *llvm::createLowerSwitchPass() {
    return new LowerSwitch();
  }
  95. bool LowerSwitch::runOnFunction(Function &F) {
  96. bool Changed = false;
  97. SmallPtrSet<BasicBlock*, 8> DeleteList;
  98. for (Function::iterator I = F.begin(), E = F.end(); I != E; ) {
  99. BasicBlock *Cur = &*I++; // Advance over block so we don't traverse new blocks
  100. // If the block is a dead Default block that will be deleted later, don't
  101. // waste time processing it.
  102. if (DeleteList.count(Cur))
  103. continue;
  104. if (SwitchInst *SI = dyn_cast<SwitchInst>(Cur->getTerminator())) {
  105. Changed = true;
  106. processSwitchInst(SI, DeleteList);
  107. }
  108. }
  109. for (BasicBlock* BB: DeleteList) {
  110. DeleteDeadBlock(BB);
  111. }
  112. return Changed;
  113. }
  114. /// Used for debugging purposes.
  115. static raw_ostream& operator<<(raw_ostream &O,
  116. const LowerSwitch::CaseVector &C)
  117. LLVM_ATTRIBUTE_USED;
  118. static raw_ostream& operator<<(raw_ostream &O,
  119. const LowerSwitch::CaseVector &C) {
  120. O << "[";
  121. for (LowerSwitch::CaseVector::const_iterator B = C.begin(),
  122. E = C.end(); B != E; ) {
  123. O << *B->Low << " -" << *B->High;
  124. if (++B != E) O << ", ";
  125. }
  126. return O << "]";
  127. }
  128. /// \brief Update the first occurrence of the "switch statement" BB in the PHI
  129. /// node with the "new" BB. The other occurrences will:
  130. ///
  131. /// 1) Be updated by subsequent calls to this function. Switch statements may
  132. /// have more than one outcoming edge into the same BB if they all have the same
  133. /// value. When the switch statement is converted these incoming edges are now
  134. /// coming from multiple BBs.
  135. /// 2) Removed if subsequent incoming values now share the same case, i.e.,
  136. /// multiple outcome edges are condensed into one. This is necessary to keep the
  137. /// number of phi values equal to the number of branches to SuccBB.
  138. static void fixPhis(BasicBlock *SuccBB, BasicBlock *OrigBB, BasicBlock *NewBB,
  139. unsigned NumMergedCases) {
  140. for (BasicBlock::iterator I = SuccBB->begin(),
  141. IE = SuccBB->getFirstNonPHI()->getIterator();
  142. I != IE; ++I) {
  143. PHINode *PN = cast<PHINode>(I);
  144. // Only update the first occurrence.
  145. unsigned Idx = 0, E = PN->getNumIncomingValues();
  146. unsigned LocalNumMergedCases = NumMergedCases;
  147. for (; Idx != E; ++Idx) {
  148. if (PN->getIncomingBlock(Idx) == OrigBB) {
  149. PN->setIncomingBlock(Idx, NewBB);
  150. break;
  151. }
  152. }
  153. // Remove additional occurrences coming from condensed cases and keep the
  154. // number of incoming values equal to the number of branches to SuccBB.
  155. SmallVector<unsigned, 8> Indices;
  156. for (++Idx; LocalNumMergedCases > 0 && Idx < E; ++Idx)
  157. if (PN->getIncomingBlock(Idx) == OrigBB) {
  158. Indices.push_back(Idx);
  159. LocalNumMergedCases--;
  160. }
  161. // Remove incoming values in the reverse order to prevent invalidating
  162. // *successive* index.
  163. for (unsigned III : reverse(Indices))
  164. PN->removeIncomingValue(III);
  165. }
  166. }
  167. /// Convert the switch statement into a binary lookup of the case values.
  168. /// The function recursively builds this tree. LowerBound and UpperBound are
  169. /// used to keep track of the bounds for Val that have already been checked by
  170. /// a block emitted by one of the previous calls to switchConvert in the call
  171. /// stack.
  172. BasicBlock *
  173. LowerSwitch::switchConvert(CaseItr Begin, CaseItr End, ConstantInt *LowerBound,
  174. ConstantInt *UpperBound, Value *Val,
  175. BasicBlock *Predecessor, BasicBlock *OrigBlock,
  176. BasicBlock *Default,
  177. const std::vector<IntRange> &UnreachableRanges) {
  178. unsigned Size = End - Begin;
  179. if (Size == 1) {
  180. // Check if the Case Range is perfectly squeezed in between
  181. // already checked Upper and Lower bounds. If it is then we can avoid
  182. // emitting the code that checks if the value actually falls in the range
  183. // because the bounds already tell us so.
  184. if (Begin->Low == LowerBound && Begin->High == UpperBound) {
  185. unsigned NumMergedCases = 0;
  186. if (LowerBound && UpperBound)
  187. NumMergedCases =
  188. UpperBound->getSExtValue() - LowerBound->getSExtValue();
  189. fixPhis(Begin->BB, OrigBlock, Predecessor, NumMergedCases);
  190. return Begin->BB;
  191. }
  192. return newLeafBlock(*Begin, Val, OrigBlock, Default);
  193. }
  194. unsigned Mid = Size / 2;
  195. std::vector<CaseRange> LHS(Begin, Begin + Mid);
  196. DEBUG(dbgs() << "LHS: " << LHS << "\n");
  197. std::vector<CaseRange> RHS(Begin + Mid, End);
  198. DEBUG(dbgs() << "RHS: " << RHS << "\n");
  199. CaseRange &Pivot = *(Begin + Mid);
  200. DEBUG(dbgs() << "Pivot ==> "
  201. << Pivot.Low->getValue()
  202. << " -" << Pivot.High->getValue() << "\n");
  203. // NewLowerBound here should never be the integer minimal value.
  204. // This is because it is computed from a case range that is never
  205. // the smallest, so there is always a case range that has at least
  206. // a smaller value.
  207. ConstantInt *NewLowerBound = Pivot.Low;
  208. // Because NewLowerBound is never the smallest representable integer
  209. // it is safe here to subtract one.
  210. ConstantInt *NewUpperBound = ConstantInt::get(NewLowerBound->getContext(),
  211. NewLowerBound->getValue() - 1);
  212. if (!UnreachableRanges.empty()) {
  213. // Check if the gap between LHS's highest and NewLowerBound is unreachable.
  214. int64_t GapLow = LHS.back().High->getSExtValue() + 1;
  215. int64_t GapHigh = NewLowerBound->getSExtValue() - 1;
  216. IntRange Gap = { GapLow, GapHigh };
  217. if (GapHigh >= GapLow && IsInRanges(Gap, UnreachableRanges))
  218. NewUpperBound = LHS.back().High;
  219. }
  220. DEBUG(dbgs() << "LHS Bounds ==> ";
  221. if (LowerBound) {
  222. dbgs() << LowerBound->getSExtValue();
  223. } else {
  224. dbgs() << "NONE";
  225. }
  226. dbgs() << " - " << NewUpperBound->getSExtValue() << "\n";
  227. dbgs() << "RHS Bounds ==> ";
  228. dbgs() << NewLowerBound->getSExtValue() << " - ";
  229. if (UpperBound) {
  230. dbgs() << UpperBound->getSExtValue() << "\n";
  231. } else {
  232. dbgs() << "NONE\n";
  233. });
  234. // Create a new node that checks if the value is < pivot. Go to the
  235. // left branch if it is and right branch if not.
  236. Function* F = OrigBlock->getParent();
  237. BasicBlock* NewNode = BasicBlock::Create(Val->getContext(), "NodeBlock");
  238. ICmpInst* Comp = new ICmpInst(ICmpInst::ICMP_SLT,
  239. Val, Pivot.Low, "Pivot");
  240. BasicBlock *LBranch = switchConvert(LHS.begin(), LHS.end(), LowerBound,
  241. NewUpperBound, Val, NewNode, OrigBlock,
  242. Default, UnreachableRanges);
  243. BasicBlock *RBranch = switchConvert(RHS.begin(), RHS.end(), NewLowerBound,
  244. UpperBound, Val, NewNode, OrigBlock,
  245. Default, UnreachableRanges);
  246. F->getBasicBlockList().insert(++OrigBlock->getIterator(), NewNode);
  247. NewNode->getInstList().push_back(Comp);
  248. BranchInst::Create(LBranch, RBranch, Comp, NewNode);
  249. return NewNode;
  250. }
  251. /// Create a new leaf block for the binary lookup tree. It checks if the
  252. /// switch's value == the case's value. If not, then it jumps to the default
  253. /// branch. At this point in the tree, the value can't be another valid case
  254. /// value, so the jump to the "default" branch is warranted.
  255. BasicBlock* LowerSwitch::newLeafBlock(CaseRange& Leaf, Value* Val,
  256. BasicBlock* OrigBlock,
  257. BasicBlock* Default)
  258. {
  259. Function* F = OrigBlock->getParent();
  260. BasicBlock* NewLeaf = BasicBlock::Create(Val->getContext(), "LeafBlock");
  261. F->getBasicBlockList().insert(++OrigBlock->getIterator(), NewLeaf);
  262. // Emit comparison
  263. ICmpInst* Comp = nullptr;
  264. if (Leaf.Low == Leaf.High) {
  265. // Make the seteq instruction...
  266. Comp = new ICmpInst(*NewLeaf, ICmpInst::ICMP_EQ, Val,
  267. Leaf.Low, "SwitchLeaf");
  268. } else {
  269. // Make range comparison
  270. if (Leaf.Low->isMinValue(true /*isSigned*/)) {
  271. // Val >= Min && Val <= Hi --> Val <= Hi
  272. Comp = new ICmpInst(*NewLeaf, ICmpInst::ICMP_SLE, Val, Leaf.High,
  273. "SwitchLeaf");
  274. } else if (Leaf.Low->isZero()) {
  275. // Val >= 0 && Val <= Hi --> Val <=u Hi
  276. Comp = new ICmpInst(*NewLeaf, ICmpInst::ICMP_ULE, Val, Leaf.High,
  277. "SwitchLeaf");
  278. } else {
  279. // Emit V-Lo <=u Hi-Lo
  280. Constant* NegLo = ConstantExpr::getNeg(Leaf.Low);
  281. Instruction* Add = BinaryOperator::CreateAdd(Val, NegLo,
  282. Val->getName()+".off",
  283. NewLeaf);
  284. Constant *UpperBound = ConstantExpr::getAdd(NegLo, Leaf.High);
  285. Comp = new ICmpInst(*NewLeaf, ICmpInst::ICMP_ULE, Add, UpperBound,
  286. "SwitchLeaf");
  287. }
  288. }
  289. // Make the conditional branch...
  290. BasicBlock* Succ = Leaf.BB;
  291. BranchInst::Create(Succ, Default, Comp, NewLeaf);
  292. // If there were any PHI nodes in this successor, rewrite one entry
  293. // from OrigBlock to come from NewLeaf.
  294. for (BasicBlock::iterator I = Succ->begin(); isa<PHINode>(I); ++I) {
  295. PHINode* PN = cast<PHINode>(I);
  296. // Remove all but one incoming entries from the cluster
  297. uint64_t Range = Leaf.High->getSExtValue() -
  298. Leaf.Low->getSExtValue();
  299. for (uint64_t j = 0; j < Range; ++j) {
  300. PN->removeIncomingValue(OrigBlock);
  301. }
  302. int BlockIdx = PN->getBasicBlockIndex(OrigBlock);
  303. assert(BlockIdx != -1 && "Switch didn't go to this successor??");
  304. PN->setIncomingBlock((unsigned)BlockIdx, NewLeaf);
  305. }
  306. return NewLeaf;
  307. }
  308. /// Transform simple list of Cases into list of CaseRange's.
  309. unsigned LowerSwitch::Clusterify(CaseVector& Cases, SwitchInst *SI) {
  310. unsigned numCmps = 0;
  311. // Start with "simple" cases
  312. for (auto Case : SI->cases())
  313. Cases.push_back(CaseRange(Case.getCaseValue(), Case.getCaseValue(),
  314. Case.getCaseSuccessor()));
  315. std::sort(Cases.begin(), Cases.end(), CaseCmp());
  316. // Merge case into clusters
  317. if (Cases.size() >= 2) {
  318. CaseItr I = Cases.begin();
  319. for (CaseItr J = std::next(I), E = Cases.end(); J != E; ++J) {
  320. int64_t nextValue = J->Low->getSExtValue();
  321. int64_t currentValue = I->High->getSExtValue();
  322. BasicBlock* nextBB = J->BB;
  323. BasicBlock* currentBB = I->BB;
  324. // If the two neighboring cases go to the same destination, merge them
  325. // into a single case.
  326. assert(nextValue > currentValue && "Cases should be strictly ascending");
  327. if ((nextValue == currentValue + 1) && (currentBB == nextBB)) {
  328. I->High = J->High;
  329. // FIXME: Combine branch weights.
  330. } else if (++I != J) {
  331. *I = *J;
  332. }
  333. }
  334. Cases.erase(std::next(I), Cases.end());
  335. }
  336. for (CaseItr I=Cases.begin(), E=Cases.end(); I!=E; ++I, ++numCmps) {
  337. if (I->Low != I->High)
  338. // A range counts double, since it requires two compares.
  339. ++numCmps;
  340. }
  341. return numCmps;
  342. }
  343. /// Replace the specified switch instruction with a sequence of chained if-then
  344. /// insts in a balanced binary search.
  345. void LowerSwitch::processSwitchInst(SwitchInst *SI,
  346. SmallPtrSetImpl<BasicBlock*> &DeleteList) {
  347. BasicBlock *CurBlock = SI->getParent();
  348. BasicBlock *OrigBlock = CurBlock;
  349. Function *F = CurBlock->getParent();
  350. Value *Val = SI->getCondition(); // The value we are switching on...
  351. BasicBlock* Default = SI->getDefaultDest();
  352. // If there is only the default destination, just branch.
  353. if (!SI->getNumCases()) {
  354. BranchInst::Create(Default, CurBlock);
  355. SI->eraseFromParent();
  356. return;
  357. }
  358. // Prepare cases vector.
  359. CaseVector Cases;
  360. unsigned numCmps = Clusterify(Cases, SI);
  361. DEBUG(dbgs() << "Clusterify finished. Total clusters: " << Cases.size()
  362. << ". Total compares: " << numCmps << "\n");
  363. DEBUG(dbgs() << "Cases: " << Cases << "\n");
  364. (void)numCmps;
  365. ConstantInt *LowerBound = nullptr;
  366. ConstantInt *UpperBound = nullptr;
  367. std::vector<IntRange> UnreachableRanges;
  368. if (isa<UnreachableInst>(Default->getFirstNonPHIOrDbg())) {
  369. // Make the bounds tightly fitted around the case value range, because we
  370. // know that the value passed to the switch must be exactly one of the case
  371. // values.
  372. assert(!Cases.empty());
  373. LowerBound = Cases.front().Low;
  374. UpperBound = Cases.back().High;
  375. DenseMap<BasicBlock *, unsigned> Popularity;
  376. unsigned MaxPop = 0;
  377. BasicBlock *PopSucc = nullptr;
  378. IntRange R = { INT64_MIN, INT64_MAX };
  379. UnreachableRanges.push_back(R);
  380. for (const auto &I : Cases) {
  381. int64_t Low = I.Low->getSExtValue();
  382. int64_t High = I.High->getSExtValue();
  383. IntRange &LastRange = UnreachableRanges.back();
  384. if (LastRange.Low == Low) {
  385. // There is nothing left of the previous range.
  386. UnreachableRanges.pop_back();
  387. } else {
  388. // Terminate the previous range.
  389. assert(Low > LastRange.Low);
  390. LastRange.High = Low - 1;
  391. }
  392. if (High != INT64_MAX) {
  393. IntRange R = { High + 1, INT64_MAX };
  394. UnreachableRanges.push_back(R);
  395. }
  396. // Count popularity.
  397. int64_t N = High - Low + 1;
  398. unsigned &Pop = Popularity[I.BB];
  399. if ((Pop += N) > MaxPop) {
  400. MaxPop = Pop;
  401. PopSucc = I.BB;
  402. }
  403. }
  404. #ifndef NDEBUG
  405. /* UnreachableRanges should be sorted and the ranges non-adjacent. */
  406. for (auto I = UnreachableRanges.begin(), E = UnreachableRanges.end();
  407. I != E; ++I) {
  408. assert(I->Low <= I->High);
  409. auto Next = I + 1;
  410. if (Next != E) {
  411. assert(Next->Low > I->High);
  412. }
  413. }
  414. #endif
  415. // Use the most popular block as the new default, reducing the number of
  416. // cases.
  417. assert(MaxPop > 0 && PopSucc);
  418. Default = PopSucc;
  419. Cases.erase(
  420. remove_if(Cases,
  421. [PopSucc](const CaseRange &R) { return R.BB == PopSucc; }),
  422. Cases.end());
  423. // If there are no cases left, just branch.
  424. if (Cases.empty()) {
  425. BranchInst::Create(Default, CurBlock);
  426. SI->eraseFromParent();
  427. return;
  428. }
  429. }
  430. // Create a new, empty default block so that the new hierarchy of
  431. // if-then statements go to this and the PHI nodes are happy.
  432. BasicBlock *NewDefault = BasicBlock::Create(SI->getContext(), "NewDefault");
  433. F->getBasicBlockList().insert(Default->getIterator(), NewDefault);
  434. BranchInst::Create(Default, NewDefault);
  435. // If there is an entry in any PHI nodes for the default edge, make sure
  436. // to update them as well.
  437. for (BasicBlock::iterator I = Default->begin(); isa<PHINode>(I); ++I) {
  438. PHINode *PN = cast<PHINode>(I);
  439. int BlockIdx = PN->getBasicBlockIndex(OrigBlock);
  440. assert(BlockIdx != -1 && "Switch didn't go to this successor??");
  441. PN->setIncomingBlock((unsigned)BlockIdx, NewDefault);
  442. }
  443. BasicBlock *SwitchBlock =
  444. switchConvert(Cases.begin(), Cases.end(), LowerBound, UpperBound, Val,
  445. OrigBlock, OrigBlock, NewDefault, UnreachableRanges);
  446. // Branch to our shiny new if-then stuff...
  447. BranchInst::Create(SwitchBlock, OrigBlock);
  448. // We are now done with the switch instruction, delete it.
  449. BasicBlock *OldDefault = SI->getDefaultDest();
  450. CurBlock->getInstList().erase(SI);
  451. // If the Default block has no more predecessors just add it to DeleteList.
  452. if (pred_begin(OldDefault) == pred_end(OldDefault))
  453. DeleteList.insert(OldDefault);
  454. }