SwitchLoweringUtils.cpp 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489
  1. //===- SwitchLoweringUtils.cpp - Switch Lowering --------------------------===//
  2. //
  3. // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
  4. // See https://llvm.org/LICENSE.txt for license information.
  5. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
  6. //
  7. //===----------------------------------------------------------------------===//
  8. //
  9. // This file contains switch inst lowering optimizations and utilities for
  10. // codegen, so that it can be used for both SelectionDAG and GlobalISel.
  11. //
  12. //===----------------------------------------------------------------------===//
  13. #include "llvm/CodeGen/MachineJumpTableInfo.h"
  14. #include "llvm/CodeGen/SwitchLoweringUtils.h"
  15. using namespace llvm;
  16. using namespace SwitchCG;
  17. uint64_t SwitchCG::getJumpTableRange(const CaseClusterVector &Clusters,
  18. unsigned First, unsigned Last) {
  19. assert(Last >= First);
  20. const APInt &LowCase = Clusters[First].Low->getValue();
  21. const APInt &HighCase = Clusters[Last].High->getValue();
  22. assert(LowCase.getBitWidth() == HighCase.getBitWidth());
  23. // FIXME: A range of consecutive cases has 100% density, but only requires one
  24. // comparison to lower. We should discriminate against such consecutive ranges
  25. // in jump tables.
  26. return (HighCase - LowCase).getLimitedValue((UINT64_MAX - 1) / 100) + 1;
  27. }
  28. uint64_t
  29. SwitchCG::getJumpTableNumCases(const SmallVectorImpl<unsigned> &TotalCases,
  30. unsigned First, unsigned Last) {
  31. assert(Last >= First);
  32. assert(TotalCases[Last] >= TotalCases[First]);
  33. uint64_t NumCases =
  34. TotalCases[Last] - (First == 0 ? 0 : TotalCases[First - 1]);
  35. return NumCases;
  36. }
  // Replace runs of Range clusters in Clusters with jump-table clusters where
  // TLI->isSuitableForJumpTable() accepts them, compacting Clusters in place.
  //
  // Preconditions (asserted in debug builds): Clusters is non-empty, sorted by
  // strictly increasing signed case value, and holds only CC_Range clusters.
  //
  // Parameters:
  //   Clusters   - rewritten in place. Each partition that becomes a jump
  //                table collapses into the single cluster produced by
  //                buildJumpTable(). All other clusters are kept, in their
  //                original order.
  //   SI         - the switch being lowered; forwarded to TLI queries and to
  //                buildJumpTable().
  //   DefaultMBB - forwarded to buildJumpTable(), which uses it for the table
  //                slots that lie between clusters.
  void SwitchCG::SwitchLowering::findJumpTables(CaseClusterVector &Clusters,
                                                const SwitchInst *SI,
                                                MachineBasicBlock *DefaultMBB) {
  #ifndef NDEBUG
    // Clusters must be non-empty, sorted, and only contain Range clusters.
    assert(!Clusters.empty());
    for (CaseCluster &C : Clusters)
      assert(C.Kind == CC_Range);
    for (unsigned i = 1, e = Clusters.size(); i < e; ++i)
      assert(Clusters[i - 1].High->getValue().slt(Clusters[i].Low->getValue()));
  #endif

    assert(TLI && "TLI not set!");
    // Nothing to do if the target disallows jump tables for this function.
    if (!TLI->areJTsAllowed(SI->getParent()->getParent()))
      return;

    const unsigned MinJumpTableEntries = TLI->getMinimumJumpTableEntries();
    // Partitions with at most this many clusters are scored as "few cases"
    // below, i.e. treated as about as good as a jump table.
    const unsigned SmallNumberOfEntries = MinJumpTableEntries / 2;

    // Bail if not enough cases.
    const int64_t N = Clusters.size();
    if (N < 2 || N < MinJumpTableEntries)
      return;

    // Accumulated number of cases in each cluster and those prior to it.
    // NOTE(review): entries are `unsigned`, so a cluster covering more than
    // UINT_MAX values would be truncated here; presumably unreachable in
    // practice — confirm.
    SmallVector<unsigned, 8> TotalCases(N);
    for (unsigned i = 0; i < N; ++i) {
      const APInt &Hi = Clusters[i].High->getValue();
      const APInt &Lo = Clusters[i].Low->getValue();
      TotalCases[i] = (Hi - Lo).getLimitedValue() + 1;
      if (i != 0)
        TotalCases[i] += TotalCases[i - 1];
    }

    // Range counts every value spanned by all clusters. NumCases counts only
    // the values that have a case. Because the clusters are sorted and
    // disjoint, Range >= NumCases.
    uint64_t Range = getJumpTableRange(Clusters,0, N - 1);
    uint64_t NumCases = getJumpTableNumCases(TotalCases, 0, N - 1);
    assert(NumCases < UINT64_MAX / 100);
    assert(Range >= NumCases);

    // Cheap case: the whole range may be suitable for jump table.
    // buildJumpTable() can still decline (it returns false when the clusters
    // are suitable for bit tests). In that case we fall through to the
    // partitioning below.
    if (TLI->isSuitableForJumpTable(SI, NumCases, Range)) {
      CaseCluster JTCluster;
      if (buildJumpTable(Clusters, 0, N - 1, SI, DefaultMBB, JTCluster)) {
        Clusters[0] = JTCluster;
        Clusters.resize(1);
        return;
      }
    }

    // The algorithm below is not suitable for -O0.
    if (TM->getOptLevel() == CodeGenOpt::None)
      return;

    // Split Clusters into minimum number of dense partitions. The algorithm uses
    // the same idea as Kannan & Proebsting "Correction to 'Producing Good Code
    // for the Case Statement'" (1994), but builds the MinPartitions array in
    // reverse order to make it easier to reconstruct the partitions in ascending
    // order. In the choice between two optimal partitionings, it picks the one
    // which yields more jump tables.
    // The nested i/j loops make O(N^2) calls to isSuitableForJumpTable().

    // MinPartitions[i] is the minimum nbr of partitions of Clusters[i..N-1].
    SmallVector<unsigned, 8> MinPartitions(N);
    // LastElement[i] is the last element of the partition starting at i.
    SmallVector<unsigned, 8> LastElement(N);
    // PartitionsScore[i] is used to break ties when choosing between two
    // partitionings resulting in the same number of partitions.
    SmallVector<unsigned, 8> PartitionsScore(N);
    // For PartitionsScore, a small number of comparisons is considered as good as
    // a jump table and a single comparison is considered better than a jump
    // table.
    // NoTable is not referenced below. A partition that matches none of the
    // scoring branches simply adds 0.
    enum PartitionScores : unsigned {
      NoTable = 0,
      Table = 1,
      FewCases = 1,
      SingleCase = 2
    };

    // Base case: There is only one way to partition Clusters[N-1].
    MinPartitions[N - 1] = 1;
    LastElement[N - 1] = N - 1;
    PartitionsScore[N - 1] = PartitionScores::SingleCase;

    // Note: loop indexes are signed to avoid underflow.
    for (int64_t i = N - 2; i >= 0; i--) {
      // Find optimal partitioning of Clusters[i..N-1].
      // Baseline: Put Clusters[i] into a partition on its own.
      MinPartitions[i] = MinPartitions[i + 1] + 1;
      LastElement[i] = i;
      PartitionsScore[i] = PartitionsScore[i + 1] + PartitionScores::SingleCase;

      // Search for a solution that results in fewer partitions.
      for (int64_t j = N - 1; j > i; j--) {
        // Try building a partition from Clusters[i..j].
        Range = getJumpTableRange(Clusters, i, j);
        NumCases = getJumpTableNumCases(TotalCases, i, j);
        assert(NumCases < UINT64_MAX / 100);
        assert(Range >= NumCases);

        if (TLI->isSuitableForJumpTable(SI, NumCases, Range)) {
          unsigned NumPartitions = 1 + (j == N - 1 ? 0 : MinPartitions[j + 1]);
          unsigned Score = j == N - 1 ? 0 : PartitionsScore[j + 1];
          // j > i here, so NumEntries >= 2 and the SingleCase branch below
          // does not fire in this loop. Singletons are scored by the baseline.
          int64_t NumEntries = j - i + 1;

          if (NumEntries == 1)
            Score += PartitionScores::SingleCase;
          else if (NumEntries <= SmallNumberOfEntries)
            Score += PartitionScores::FewCases;
          else if (NumEntries >= MinJumpTableEntries)
            Score += PartitionScores::Table;

          // If this leads to fewer partitions, or to the same number of
          // partitions with better score, it is a better partitioning.
          if (NumPartitions < MinPartitions[i] ||
              (NumPartitions == MinPartitions[i] && Score > PartitionsScore[i])) {
            MinPartitions[i] = NumPartitions;
            LastElement[i] = j;
            PartitionsScore[i] = Score;
          }
        }
      }
    }

    // Iterate over the partitions, replacing some with jump tables in-place.
    // DstIndex never passes First (asserted), so moving clusters down never
    // overwrites one that has not been read yet. Only partitions with at least
    // MinJumpTableEntries clusters are offered to buildJumpTable(); the rest
    // are copied down cluster by cluster.
    // NOTE(review): std::memmove copies CaseCluster bytewise, which assumes the
    // type is trivially copyable — confirm in SwitchLoweringUtils.h.
    unsigned DstIndex = 0;
    for (unsigned First = 0, Last; First < N; First = Last + 1) {
      Last = LastElement[First];
      assert(Last >= First);
      assert(DstIndex <= First);
      unsigned NumClusters = Last - First + 1;

      CaseCluster JTCluster;
      if (NumClusters >= MinJumpTableEntries &&
          buildJumpTable(Clusters, First, Last, SI, DefaultMBB, JTCluster)) {
        Clusters[DstIndex++] = JTCluster;
      } else {
        for (unsigned I = First; I <= Last; ++I)
          std::memmove(&Clusters[DstIndex++], &Clusters[I], sizeof(Clusters[I]));
      }
    }
    Clusters.resize(DstIndex);
  }
  161. bool SwitchCG::SwitchLowering::buildJumpTable(const CaseClusterVector &Clusters,
  162. unsigned First, unsigned Last,
  163. const SwitchInst *SI,
  164. MachineBasicBlock *DefaultMBB,
  165. CaseCluster &JTCluster) {
  166. assert(First <= Last);
  167. auto Prob = BranchProbability::getZero();
  168. unsigned NumCmps = 0;
  169. std::vector<MachineBasicBlock*> Table;
  170. DenseMap<MachineBasicBlock*, BranchProbability> JTProbs;
  171. // Initialize probabilities in JTProbs.
  172. for (unsigned I = First; I <= Last; ++I)
  173. JTProbs[Clusters[I].MBB] = BranchProbability::getZero();
  174. for (unsigned I = First; I <= Last; ++I) {
  175. assert(Clusters[I].Kind == CC_Range);
  176. Prob += Clusters[I].Prob;
  177. const APInt &Low = Clusters[I].Low->getValue();
  178. const APInt &High = Clusters[I].High->getValue();
  179. NumCmps += (Low == High) ? 1 : 2;
  180. if (I != First) {
  181. // Fill the gap between this and the previous cluster.
  182. const APInt &PreviousHigh = Clusters[I - 1].High->getValue();
  183. assert(PreviousHigh.slt(Low));
  184. uint64_t Gap = (Low - PreviousHigh).getLimitedValue() - 1;
  185. for (uint64_t J = 0; J < Gap; J++)
  186. Table.push_back(DefaultMBB);
  187. }
  188. uint64_t ClusterSize = (High - Low).getLimitedValue() + 1;
  189. for (uint64_t J = 0; J < ClusterSize; ++J)
  190. Table.push_back(Clusters[I].MBB);
  191. JTProbs[Clusters[I].MBB] += Clusters[I].Prob;
  192. }
  193. unsigned NumDests = JTProbs.size();
  194. if (TLI->isSuitableForBitTests(NumDests, NumCmps,
  195. Clusters[First].Low->getValue(),
  196. Clusters[Last].High->getValue(), *DL)) {
  197. // Clusters[First..Last] should be lowered as bit tests instead.
  198. return false;
  199. }
  200. // Create the MBB that will load from and jump through the table.
  201. // Note: We create it here, but it's not inserted into the function yet.
  202. MachineFunction *CurMF = FuncInfo.MF;
  203. MachineBasicBlock *JumpTableMBB =
  204. CurMF->CreateMachineBasicBlock(SI->getParent());
  205. // Add successors. Note: use table order for determinism.
  206. SmallPtrSet<MachineBasicBlock *, 8> Done;
  207. for (MachineBasicBlock *Succ : Table) {
  208. if (Done.count(Succ))
  209. continue;
  210. addSuccessorWithProb(JumpTableMBB, Succ, JTProbs[Succ]);
  211. Done.insert(Succ);
  212. }
  213. JumpTableMBB->normalizeSuccProbs();
  214. unsigned JTI = CurMF->getOrCreateJumpTableInfo(TLI->getJumpTableEncoding())
  215. ->createJumpTableIndex(Table);
  216. // Set up the jump table info.
  217. JumpTable JT(-1U, JTI, JumpTableMBB, nullptr);
  218. JumpTableHeader JTH(Clusters[First].Low->getValue(),
  219. Clusters[Last].High->getValue(), SI->getCondition(),
  220. nullptr, false);
  221. JTCases.emplace_back(std::move(JTH), std::move(JT));
  222. JTCluster = CaseCluster::jumpTable(Clusters[First].Low, Clusters[Last].High,
  223. JTCases.size() - 1, Prob);
  224. return true;
  225. }
  226. void SwitchCG::SwitchLowering::findBitTestClusters(CaseClusterVector &Clusters,
  227. const SwitchInst *SI) {
  228. // Partition Clusters into as few subsets as possible, where each subset has a
  229. // range that fits in a machine word and has <= 3 unique destinations.
  230. #ifndef NDEBUG
  231. // Clusters must be sorted and contain Range or JumpTable clusters.
  232. assert(!Clusters.empty());
  233. assert(Clusters[0].Kind == CC_Range || Clusters[0].Kind == CC_JumpTable);
  234. for (const CaseCluster &C : Clusters)
  235. assert(C.Kind == CC_Range || C.Kind == CC_JumpTable);
  236. for (unsigned i = 1; i < Clusters.size(); ++i)
  237. assert(Clusters[i-1].High->getValue().slt(Clusters[i].Low->getValue()));
  238. #endif
  239. // The algorithm below is not suitable for -O0.
  240. if (TM->getOptLevel() == CodeGenOpt::None)
  241. return;
  242. // If target does not have legal shift left, do not emit bit tests at all.
  243. EVT PTy = TLI->getPointerTy(*DL);
  244. if (!TLI->isOperationLegal(ISD::SHL, PTy))
  245. return;
  246. int BitWidth = PTy.getSizeInBits();
  247. const int64_t N = Clusters.size();
  248. // MinPartitions[i] is the minimum nbr of partitions of Clusters[i..N-1].
  249. SmallVector<unsigned, 8> MinPartitions(N);
  250. // LastElement[i] is the last element of the partition starting at i.
  251. SmallVector<unsigned, 8> LastElement(N);
  252. // FIXME: This might not be the best algorithm for finding bit test clusters.
  253. // Base case: There is only one way to partition Clusters[N-1].
  254. MinPartitions[N - 1] = 1;
  255. LastElement[N - 1] = N - 1;
  256. // Note: loop indexes are signed to avoid underflow.
  257. for (int64_t i = N - 2; i >= 0; --i) {
  258. // Find optimal partitioning of Clusters[i..N-1].
  259. // Baseline: Put Clusters[i] into a partition on its own.
  260. MinPartitions[i] = MinPartitions[i + 1] + 1;
  261. LastElement[i] = i;
  262. // Search for a solution that results in fewer partitions.
  263. // Note: the search is limited by BitWidth, reducing time complexity.
  264. for (int64_t j = std::min(N - 1, i + BitWidth - 1); j > i; --j) {
  265. // Try building a partition from Clusters[i..j].
  266. // Check the range.
  267. if (!TLI->rangeFitsInWord(Clusters[i].Low->getValue(),
  268. Clusters[j].High->getValue(), *DL))
  269. continue;
  270. // Check nbr of destinations and cluster types.
  271. // FIXME: This works, but doesn't seem very efficient.
  272. bool RangesOnly = true;
  273. BitVector Dests(FuncInfo.MF->getNumBlockIDs());
  274. for (int64_t k = i; k <= j; k++) {
  275. if (Clusters[k].Kind != CC_Range) {
  276. RangesOnly = false;
  277. break;
  278. }
  279. Dests.set(Clusters[k].MBB->getNumber());
  280. }
  281. if (!RangesOnly || Dests.count() > 3)
  282. break;
  283. // Check if it's a better partition.
  284. unsigned NumPartitions = 1 + (j == N - 1 ? 0 : MinPartitions[j + 1]);
  285. if (NumPartitions < MinPartitions[i]) {
  286. // Found a better partition.
  287. MinPartitions[i] = NumPartitions;
  288. LastElement[i] = j;
  289. }
  290. }
  291. }
  292. // Iterate over the partitions, replacing with bit-test clusters in-place.
  293. unsigned DstIndex = 0;
  294. for (unsigned First = 0, Last; First < N; First = Last + 1) {
  295. Last = LastElement[First];
  296. assert(First <= Last);
  297. assert(DstIndex <= First);
  298. CaseCluster BitTestCluster;
  299. if (buildBitTests(Clusters, First, Last, SI, BitTestCluster)) {
  300. Clusters[DstIndex++] = BitTestCluster;
  301. } else {
  302. size_t NumClusters = Last - First + 1;
  303. std::memmove(&Clusters[DstIndex], &Clusters[First],
  304. sizeof(Clusters[0]) * NumClusters);
  305. DstIndex += NumClusters;
  306. }
  307. }
  308. Clusters.resize(DstIndex);
  309. }
  // Try to lower Clusters[First..Last] as a sequence of bit tests.
  //
  // Returns false, before touching BitTestCases or creating any block, when
  // the partition is a single cluster or when TLI->isSuitableForBitTests()
  // rejects it.
  //
  // Otherwise it:
  //   - creates one (not yet inserted) block per distinct destination;
  //   - appends one entry to BitTestCases;
  //   - stores the cluster built by CaseCluster::bitTests() in BTCluster;
  //   - returns true.
  //
  // Every cluster in the partition must be CC_Range (asserted below).
  bool SwitchCG::SwitchLowering::buildBitTests(CaseClusterVector &Clusters,
                                               unsigned First, unsigned Last,
                                               const SwitchInst *SI,
                                               CaseCluster &BTCluster) {
    assert(First <= Last);
    // A lone cluster is never turned into bit tests.
    if (First == Last)
      return false;

    // Count distinct destinations and the comparisons that a compare-and-branch
    // lowering would need: 1 for a single value, 2 for a range.
    // NOTE(review): `Low == High` compares ConstantInt pointers, whereas
    // buildJumpTable() compares APInt values. Presumably this is equivalent
    // because constants are uniqued — confirm.
    BitVector Dests(FuncInfo.MF->getNumBlockIDs());
    unsigned NumCmps = 0;
    for (int64_t I = First; I <= Last; ++I) {
      assert(Clusters[I].Kind == CC_Range);
      Dests.set(Clusters[I].MBB->getNumber());
      NumCmps += (Clusters[I].Low == Clusters[I].High) ? 1 : 2;
    }
    unsigned NumDests = Dests.count();

    APInt Low = Clusters[First].Low->getValue();
    APInt High = Clusters[Last].High->getValue();
    assert(Low.slt(High));

    if (!TLI->isSuitableForBitTests(NumDests, NumCmps, Low, High, *DL))
      return false;

    // Bit positions in the masks below are (case value - LowBound). CmpRange
    // is the largest such offset, i.e. High - LowBound.
    APInt LowBound;
    APInt CmpRange;

    const int BitWidth = TLI->getPointerTy(*DL).getSizeInBits();
    assert(TLI->rangeFitsInWord(Low, High, *DL) &&
           "Case range must fit in bit mask!");

    // Check if the clusters cover a contiguous range such that no value in the
    // range will jump to the default statement.
    bool ContiguousRange = true;
    for (int64_t I = First + 1; I <= Last; ++I) {
      if (Clusters[I].Low->getValue() != Clusters[I - 1].High->getValue() + 1) {
        ContiguousRange = false;
        break;
      }
    }

    if (Low.isStrictlyPositive() && High.slt(BitWidth)) {
      // Optimize the case where all the case values fit in a word without having
      // to subtract minValue. In this case, we can optimize away the subtraction.
      // With LowBound at 0, the values in [0, Low) fall inside the tested range
      // without belonging to any cluster, so the range is no longer contiguous.
      LowBound = APInt::getNullValue(Low.getBitWidth());
      CmpRange = High;
      ContiguousRange = false;
    } else {
      LowBound = Low;
      CmpRange = High - Low;
    }

    // Build one CaseBits entry per distinct destination:
    //   - Mask has a bit set for every offset that branches there;
    //   - Bits counts those offsets;
    //   - ExtraProb sums the clusters' probabilities.
    // TotalProb sums the probabilities over all clusters.
    CaseBitsVector CBV;
    auto TotalProb = BranchProbability::getZero();
    for (unsigned i = First; i <= Last; ++i) {
      // Find the CaseBits for this destination (linear search over the
      // entries created so far; a new entry is appended if none matches).
      unsigned j;
      for (j = 0; j < CBV.size(); ++j)
        if (CBV[j].BB == Clusters[i].MBB)
          break;
      if (j == CBV.size())
        CBV.push_back(
            CaseBits(0, Clusters[i].MBB, 0, BranchProbability::getZero()));
      CaseBits *CB = &CBV[j];

      // Update Mask, Bits and ExtraProb.
      uint64_t Lo = (Clusters[i].Low->getValue() - LowBound).getZExtValue();
      uint64_t Hi = (Clusters[i].High->getValue() - LowBound).getZExtValue();
      assert(Hi >= Lo && Hi < 64 && "Invalid bit case!");
      // (-1ULL >> (63 - (Hi - Lo))) has the low (Hi - Lo + 1) bits set.
      // Shifting left by Lo moves them to bit positions Lo..Hi.
      CB->Mask |= (-1ULL >> (63 - (Hi - Lo))) << Lo;
      CB->Bits += Hi - Lo + 1;
      CB->ExtraProb += Clusters[i].Prob;
      TotalProb += Clusters[i].Prob;
    }

    BitTestInfo BTI;
    llvm::sort(CBV, [](const CaseBits &a, const CaseBits &b) {
      // Sort by probability first, number of bits second, bit mask third.
      // Probability and bit count are descending; the mask is ascending.
      if (a.ExtraProb != b.ExtraProb)
        return a.ExtraProb > b.ExtraProb;
      if (a.Bits != b.Bits)
        return a.Bits > b.Bits;
      return a.Mask < b.Mask;
    });

    // Create one new block per destination, in the sorted order above.
    for (auto &CB : CBV) {
      MachineBasicBlock *BitTestBB =
          FuncInfo.MF->CreateMachineBasicBlock(SI->getParent());
      BTI.push_back(BitTestCase(CB.Mask, BitTestBB, CB.BB, CB.ExtraProb));
    }
    // NOTE(review): the -1U / MVT::Other / false / nullptr arguments look like
    // placeholders that are completed later in lowering — confirm against the
    // BitTestBlock constructor.
    BitTestCases.emplace_back(std::move(LowBound), std::move(CmpRange),
                              SI->getCondition(), -1U, MVT::Other, false,
                              ContiguousRange, nullptr, nullptr, std::move(BTI),
                              TotalProb);

    BTCluster = CaseCluster::bitTests(Clusters[First].Low, Clusters[Last].High,
                                      BitTestCases.size() - 1, TotalProb);
    return true;
  }
  397. void SwitchCG::sortAndRangeify(CaseClusterVector &Clusters) {
  398. #ifndef NDEBUG
  399. for (const CaseCluster &CC : Clusters)
  400. assert(CC.Low == CC.High && "Input clusters must be single-case");
  401. #endif
  402. llvm::sort(Clusters, [](const CaseCluster &a, const CaseCluster &b) {
  403. return a.Low->getValue().slt(b.Low->getValue());
  404. });
  405. // Merge adjacent clusters with the same destination.
  406. const unsigned N = Clusters.size();
  407. unsigned DstIndex = 0;
  408. for (unsigned SrcIndex = 0; SrcIndex < N; ++SrcIndex) {
  409. CaseCluster &CC = Clusters[SrcIndex];
  410. const ConstantInt *CaseVal = CC.Low;
  411. MachineBasicBlock *Succ = CC.MBB;
  412. if (DstIndex != 0 && Clusters[DstIndex - 1].MBB == Succ &&
  413. (CaseVal->getValue() - Clusters[DstIndex - 1].High->getValue()) == 1) {
  414. // If this case has the same successor and is a neighbour, merge it into
  415. // the previous cluster.
  416. Clusters[DstIndex - 1].High = CaseVal;
  417. Clusters[DstIndex - 1].Prob += CC.Prob;
  418. } else {
  419. std::memmove(&Clusters[DstIndex++], &Clusters[SrcIndex],
  420. sizeof(Clusters[SrcIndex]));
  421. }
  422. }
  423. Clusters.resize(DstIndex);
  424. }