SwitchLoweringUtils.cpp 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488
  1. //===- SwitchLoweringUtils.cpp - Switch Lowering --------------------------===//
  2. //
  3. // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
  4. // See https://llvm.org/LICENSE.txt for license information.
  5. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
  6. //
  7. //===----------------------------------------------------------------------===//
  8. //
  9. // This file contains switch inst lowering optimizations and utilities for
  10. // codegen, so that it can be used for both SelectionDAG and GlobalISel.
  11. //
  12. //===----------------------------------------------------------------------===//
  13. #include "llvm/CodeGen/MachineJumpTableInfo.h"
  14. #include "llvm/CodeGen/SwitchLoweringUtils.h"
  15. using namespace llvm;
  16. using namespace SwitchCG;
  17. uint64_t SwitchCG::getJumpTableRange(const CaseClusterVector &Clusters,
  18. unsigned First, unsigned Last) {
  19. assert(Last >= First);
  20. const APInt &LowCase = Clusters[First].Low->getValue();
  21. const APInt &HighCase = Clusters[Last].High->getValue();
  22. assert(LowCase.getBitWidth() == HighCase.getBitWidth());
  23. // FIXME: A range of consecutive cases has 100% density, but only requires one
  24. // comparison to lower. We should discriminate against such consecutive ranges
  25. // in jump tables.
  26. return (HighCase - LowCase).getLimitedValue((UINT64_MAX - 1) / 100) + 1;
  27. }
  28. uint64_t
  29. SwitchCG::getJumpTableNumCases(const SmallVectorImpl<unsigned> &TotalCases,
  30. unsigned First, unsigned Last) {
  31. assert(Last >= First);
  32. assert(TotalCases[Last] >= TotalCases[First]);
  33. uint64_t NumCases =
  34. TotalCases[Last] - (First == 0 ? 0 : TotalCases[First - 1]);
  35. return NumCases;
  36. }
  37. void SwitchCG::SwitchLowering::findJumpTables(CaseClusterVector &Clusters,
  38. const SwitchInst *SI,
  39. MachineBasicBlock *DefaultMBB) {
  40. #ifndef NDEBUG
  41. // Clusters must be non-empty, sorted, and only contain Range clusters.
  42. assert(!Clusters.empty());
  43. for (CaseCluster &C : Clusters)
  44. assert(C.Kind == CC_Range);
  45. for (unsigned i = 1, e = Clusters.size(); i < e; ++i)
  46. assert(Clusters[i - 1].High->getValue().slt(Clusters[i].Low->getValue()));
  47. #endif
  48. if (!TLI->areJTsAllowed(SI->getParent()->getParent()))
  49. return;
  50. const unsigned MinJumpTableEntries = TLI->getMinimumJumpTableEntries();
  51. const unsigned SmallNumberOfEntries = MinJumpTableEntries / 2;
  52. // Bail if not enough cases.
  53. const int64_t N = Clusters.size();
  54. if (N < 2 || N < MinJumpTableEntries)
  55. return;
  56. // Accumulated number of cases in each cluster and those prior to it.
  57. SmallVector<unsigned, 8> TotalCases(N);
  58. for (unsigned i = 0; i < N; ++i) {
  59. const APInt &Hi = Clusters[i].High->getValue();
  60. const APInt &Lo = Clusters[i].Low->getValue();
  61. TotalCases[i] = (Hi - Lo).getLimitedValue() + 1;
  62. if (i != 0)
  63. TotalCases[i] += TotalCases[i - 1];
  64. }
  65. uint64_t Range = getJumpTableRange(Clusters,0, N - 1);
  66. uint64_t NumCases = getJumpTableNumCases(TotalCases, 0, N - 1);
  67. assert(NumCases < UINT64_MAX / 100);
  68. assert(Range >= NumCases);
  69. // Cheap case: the whole range may be suitable for jump table.
  70. if (TLI->isSuitableForJumpTable(SI, NumCases, Range)) {
  71. CaseCluster JTCluster;
  72. if (buildJumpTable(Clusters, 0, N - 1, SI, DefaultMBB, JTCluster)) {
  73. Clusters[0] = JTCluster;
  74. Clusters.resize(1);
  75. return;
  76. }
  77. }
  78. // The algorithm below is not suitable for -O0.
  79. if (TM->getOptLevel() == CodeGenOpt::None)
  80. return;
  81. // Split Clusters into minimum number of dense partitions. The algorithm uses
  82. // the same idea as Kannan & Proebsting "Correction to 'Producing Good Code
  83. // for the Case Statement'" (1994), but builds the MinPartitions array in
  84. // reverse order to make it easier to reconstruct the partitions in ascending
  85. // order. In the choice between two optimal partitionings, it picks the one
  86. // which yields more jump tables.
  87. // MinPartitions[i] is the minimum nbr of partitions of Clusters[i..N-1].
  88. SmallVector<unsigned, 8> MinPartitions(N);
  89. // LastElement[i] is the last element of the partition starting at i.
  90. SmallVector<unsigned, 8> LastElement(N);
  91. // PartitionsScore[i] is used to break ties when choosing between two
  92. // partitionings resulting in the same number of partitions.
  93. SmallVector<unsigned, 8> PartitionsScore(N);
  94. // For PartitionsScore, a small number of comparisons is considered as good as
  95. // a jump table and a single comparison is considered better than a jump
  96. // table.
  97. enum PartitionScores : unsigned {
  98. NoTable = 0,
  99. Table = 1,
  100. FewCases = 1,
  101. SingleCase = 2
  102. };
  103. // Base case: There is only one way to partition Clusters[N-1].
  104. MinPartitions[N - 1] = 1;
  105. LastElement[N - 1] = N - 1;
  106. PartitionsScore[N - 1] = PartitionScores::SingleCase;
  107. // Note: loop indexes are signed to avoid underflow.
  108. for (int64_t i = N - 2; i >= 0; i--) {
  109. // Find optimal partitioning of Clusters[i..N-1].
  110. // Baseline: Put Clusters[i] into a partition on its own.
  111. MinPartitions[i] = MinPartitions[i + 1] + 1;
  112. LastElement[i] = i;
  113. PartitionsScore[i] = PartitionsScore[i + 1] + PartitionScores::SingleCase;
  114. // Search for a solution that results in fewer partitions.
  115. for (int64_t j = N - 1; j > i; j--) {
  116. // Try building a partition from Clusters[i..j].
  117. Range = getJumpTableRange(Clusters, i, j);
  118. NumCases = getJumpTableNumCases(TotalCases, i, j);
  119. assert(NumCases < UINT64_MAX / 100);
  120. assert(Range >= NumCases);
  121. if (TLI->isSuitableForJumpTable(SI, NumCases, Range)) {
  122. unsigned NumPartitions = 1 + (j == N - 1 ? 0 : MinPartitions[j + 1]);
  123. unsigned Score = j == N - 1 ? 0 : PartitionsScore[j + 1];
  124. int64_t NumEntries = j - i + 1;
  125. if (NumEntries == 1)
  126. Score += PartitionScores::SingleCase;
  127. else if (NumEntries <= SmallNumberOfEntries)
  128. Score += PartitionScores::FewCases;
  129. else if (NumEntries >= MinJumpTableEntries)
  130. Score += PartitionScores::Table;
  131. // If this leads to fewer partitions, or to the same number of
  132. // partitions with better score, it is a better partitioning.
  133. if (NumPartitions < MinPartitions[i] ||
  134. (NumPartitions == MinPartitions[i] && Score > PartitionsScore[i])) {
  135. MinPartitions[i] = NumPartitions;
  136. LastElement[i] = j;
  137. PartitionsScore[i] = Score;
  138. }
  139. }
  140. }
  141. }
  142. // Iterate over the partitions, replacing some with jump tables in-place.
  143. unsigned DstIndex = 0;
  144. for (unsigned First = 0, Last; First < N; First = Last + 1) {
  145. Last = LastElement[First];
  146. assert(Last >= First);
  147. assert(DstIndex <= First);
  148. unsigned NumClusters = Last - First + 1;
  149. CaseCluster JTCluster;
  150. if (NumClusters >= MinJumpTableEntries &&
  151. buildJumpTable(Clusters, First, Last, SI, DefaultMBB, JTCluster)) {
  152. Clusters[DstIndex++] = JTCluster;
  153. } else {
  154. for (unsigned I = First; I <= Last; ++I)
  155. std::memmove(&Clusters[DstIndex++], &Clusters[I], sizeof(Clusters[I]));
  156. }
  157. }
  158. Clusters.resize(DstIndex);
  159. }
  160. bool SwitchCG::SwitchLowering::buildJumpTable(const CaseClusterVector &Clusters,
  161. unsigned First, unsigned Last,
  162. const SwitchInst *SI,
  163. MachineBasicBlock *DefaultMBB,
  164. CaseCluster &JTCluster) {
  165. assert(First <= Last);
  166. auto Prob = BranchProbability::getZero();
  167. unsigned NumCmps = 0;
  168. std::vector<MachineBasicBlock*> Table;
  169. DenseMap<MachineBasicBlock*, BranchProbability> JTProbs;
  170. // Initialize probabilities in JTProbs.
  171. for (unsigned I = First; I <= Last; ++I)
  172. JTProbs[Clusters[I].MBB] = BranchProbability::getZero();
  173. for (unsigned I = First; I <= Last; ++I) {
  174. assert(Clusters[I].Kind == CC_Range);
  175. Prob += Clusters[I].Prob;
  176. const APInt &Low = Clusters[I].Low->getValue();
  177. const APInt &High = Clusters[I].High->getValue();
  178. NumCmps += (Low == High) ? 1 : 2;
  179. if (I != First) {
  180. // Fill the gap between this and the previous cluster.
  181. const APInt &PreviousHigh = Clusters[I - 1].High->getValue();
  182. assert(PreviousHigh.slt(Low));
  183. uint64_t Gap = (Low - PreviousHigh).getLimitedValue() - 1;
  184. for (uint64_t J = 0; J < Gap; J++)
  185. Table.push_back(DefaultMBB);
  186. }
  187. uint64_t ClusterSize = (High - Low).getLimitedValue() + 1;
  188. for (uint64_t J = 0; J < ClusterSize; ++J)
  189. Table.push_back(Clusters[I].MBB);
  190. JTProbs[Clusters[I].MBB] += Clusters[I].Prob;
  191. }
  192. unsigned NumDests = JTProbs.size();
  193. if (TLI->isSuitableForBitTests(NumDests, NumCmps,
  194. Clusters[First].Low->getValue(),
  195. Clusters[Last].High->getValue(), *DL)) {
  196. // Clusters[First..Last] should be lowered as bit tests instead.
  197. return false;
  198. }
  199. // Create the MBB that will load from and jump through the table.
  200. // Note: We create it here, but it's not inserted into the function yet.
  201. MachineFunction *CurMF = FuncInfo.MF;
  202. MachineBasicBlock *JumpTableMBB =
  203. CurMF->CreateMachineBasicBlock(SI->getParent());
  204. // Add successors. Note: use table order for determinism.
  205. SmallPtrSet<MachineBasicBlock *, 8> Done;
  206. for (MachineBasicBlock *Succ : Table) {
  207. if (Done.count(Succ))
  208. continue;
  209. addSuccessorWithProb(JumpTableMBB, Succ, JTProbs[Succ]);
  210. Done.insert(Succ);
  211. }
  212. JumpTableMBB->normalizeSuccProbs();
  213. unsigned JTI = CurMF->getOrCreateJumpTableInfo(TLI->getJumpTableEncoding())
  214. ->createJumpTableIndex(Table);
  215. // Set up the jump table info.
  216. JumpTable JT(-1U, JTI, JumpTableMBB, nullptr);
  217. JumpTableHeader JTH(Clusters[First].Low->getValue(),
  218. Clusters[Last].High->getValue(), SI->getCondition(),
  219. nullptr, false);
  220. JTCases.emplace_back(std::move(JTH), std::move(JT));
  221. JTCluster = CaseCluster::jumpTable(Clusters[First].Low, Clusters[Last].High,
  222. JTCases.size() - 1, Prob);
  223. return true;
  224. }
  225. void SwitchCG::SwitchLowering::findBitTestClusters(CaseClusterVector &Clusters,
  226. const SwitchInst *SI) {
  227. // Partition Clusters into as few subsets as possible, where each subset has a
  228. // range that fits in a machine word and has <= 3 unique destinations.
  229. #ifndef NDEBUG
  230. // Clusters must be sorted and contain Range or JumpTable clusters.
  231. assert(!Clusters.empty());
  232. assert(Clusters[0].Kind == CC_Range || Clusters[0].Kind == CC_JumpTable);
  233. for (const CaseCluster &C : Clusters)
  234. assert(C.Kind == CC_Range || C.Kind == CC_JumpTable);
  235. for (unsigned i = 1; i < Clusters.size(); ++i)
  236. assert(Clusters[i-1].High->getValue().slt(Clusters[i].Low->getValue()));
  237. #endif
  238. // The algorithm below is not suitable for -O0.
  239. if (TM->getOptLevel() == CodeGenOpt::None)
  240. return;
  241. // If target does not have legal shift left, do not emit bit tests at all.
  242. EVT PTy = TLI->getPointerTy(*DL);
  243. if (!TLI->isOperationLegal(ISD::SHL, PTy))
  244. return;
  245. int BitWidth = PTy.getSizeInBits();
  246. const int64_t N = Clusters.size();
  247. // MinPartitions[i] is the minimum nbr of partitions of Clusters[i..N-1].
  248. SmallVector<unsigned, 8> MinPartitions(N);
  249. // LastElement[i] is the last element of the partition starting at i.
  250. SmallVector<unsigned, 8> LastElement(N);
  251. // FIXME: This might not be the best algorithm for finding bit test clusters.
  252. // Base case: There is only one way to partition Clusters[N-1].
  253. MinPartitions[N - 1] = 1;
  254. LastElement[N - 1] = N - 1;
  255. // Note: loop indexes are signed to avoid underflow.
  256. for (int64_t i = N - 2; i >= 0; --i) {
  257. // Find optimal partitioning of Clusters[i..N-1].
  258. // Baseline: Put Clusters[i] into a partition on its own.
  259. MinPartitions[i] = MinPartitions[i + 1] + 1;
  260. LastElement[i] = i;
  261. // Search for a solution that results in fewer partitions.
  262. // Note: the search is limited by BitWidth, reducing time complexity.
  263. for (int64_t j = std::min(N - 1, i + BitWidth - 1); j > i; --j) {
  264. // Try building a partition from Clusters[i..j].
  265. // Check the range.
  266. if (!TLI->rangeFitsInWord(Clusters[i].Low->getValue(),
  267. Clusters[j].High->getValue(), *DL))
  268. continue;
  269. // Check nbr of destinations and cluster types.
  270. // FIXME: This works, but doesn't seem very efficient.
  271. bool RangesOnly = true;
  272. BitVector Dests(FuncInfo.MF->getNumBlockIDs());
  273. for (int64_t k = i; k <= j; k++) {
  274. if (Clusters[k].Kind != CC_Range) {
  275. RangesOnly = false;
  276. break;
  277. }
  278. Dests.set(Clusters[k].MBB->getNumber());
  279. }
  280. if (!RangesOnly || Dests.count() > 3)
  281. break;
  282. // Check if it's a better partition.
  283. unsigned NumPartitions = 1 + (j == N - 1 ? 0 : MinPartitions[j + 1]);
  284. if (NumPartitions < MinPartitions[i]) {
  285. // Found a better partition.
  286. MinPartitions[i] = NumPartitions;
  287. LastElement[i] = j;
  288. }
  289. }
  290. }
  291. // Iterate over the partitions, replacing with bit-test clusters in-place.
  292. unsigned DstIndex = 0;
  293. for (unsigned First = 0, Last; First < N; First = Last + 1) {
  294. Last = LastElement[First];
  295. assert(First <= Last);
  296. assert(DstIndex <= First);
  297. CaseCluster BitTestCluster;
  298. if (buildBitTests(Clusters, First, Last, SI, BitTestCluster)) {
  299. Clusters[DstIndex++] = BitTestCluster;
  300. } else {
  301. size_t NumClusters = Last - First + 1;
  302. std::memmove(&Clusters[DstIndex], &Clusters[First],
  303. sizeof(Clusters[0]) * NumClusters);
  304. DstIndex += NumClusters;
  305. }
  306. }
  307. Clusters.resize(DstIndex);
  308. }
  309. bool SwitchCG::SwitchLowering::buildBitTests(CaseClusterVector &Clusters,
  310. unsigned First, unsigned Last,
  311. const SwitchInst *SI,
  312. CaseCluster &BTCluster) {
  313. assert(First <= Last);
  314. if (First == Last)
  315. return false;
  316. BitVector Dests(FuncInfo.MF->getNumBlockIDs());
  317. unsigned NumCmps = 0;
  318. for (int64_t I = First; I <= Last; ++I) {
  319. assert(Clusters[I].Kind == CC_Range);
  320. Dests.set(Clusters[I].MBB->getNumber());
  321. NumCmps += (Clusters[I].Low == Clusters[I].High) ? 1 : 2;
  322. }
  323. unsigned NumDests = Dests.count();
  324. APInt Low = Clusters[First].Low->getValue();
  325. APInt High = Clusters[Last].High->getValue();
  326. assert(Low.slt(High));
  327. if (!TLI->isSuitableForBitTests(NumDests, NumCmps, Low, High, *DL))
  328. return false;
  329. APInt LowBound;
  330. APInt CmpRange;
  331. const int BitWidth = TLI->getPointerTy(*DL).getSizeInBits();
  332. assert(TLI->rangeFitsInWord(Low, High, *DL) &&
  333. "Case range must fit in bit mask!");
  334. // Check if the clusters cover a contiguous range such that no value in the
  335. // range will jump to the default statement.
  336. bool ContiguousRange = true;
  337. for (int64_t I = First + 1; I <= Last; ++I) {
  338. if (Clusters[I].Low->getValue() != Clusters[I - 1].High->getValue() + 1) {
  339. ContiguousRange = false;
  340. break;
  341. }
  342. }
  343. if (Low.isStrictlyPositive() && High.slt(BitWidth)) {
  344. // Optimize the case where all the case values fit in a word without having
  345. // to subtract minValue. In this case, we can optimize away the subtraction.
  346. LowBound = APInt::getNullValue(Low.getBitWidth());
  347. CmpRange = High;
  348. ContiguousRange = false;
  349. } else {
  350. LowBound = Low;
  351. CmpRange = High - Low;
  352. }
  353. CaseBitsVector CBV;
  354. auto TotalProb = BranchProbability::getZero();
  355. for (unsigned i = First; i <= Last; ++i) {
  356. // Find the CaseBits for this destination.
  357. unsigned j;
  358. for (j = 0; j < CBV.size(); ++j)
  359. if (CBV[j].BB == Clusters[i].MBB)
  360. break;
  361. if (j == CBV.size())
  362. CBV.push_back(
  363. CaseBits(0, Clusters[i].MBB, 0, BranchProbability::getZero()));
  364. CaseBits *CB = &CBV[j];
  365. // Update Mask, Bits and ExtraProb.
  366. uint64_t Lo = (Clusters[i].Low->getValue() - LowBound).getZExtValue();
  367. uint64_t Hi = (Clusters[i].High->getValue() - LowBound).getZExtValue();
  368. assert(Hi >= Lo && Hi < 64 && "Invalid bit case!");
  369. CB->Mask |= (-1ULL >> (63 - (Hi - Lo))) << Lo;
  370. CB->Bits += Hi - Lo + 1;
  371. CB->ExtraProb += Clusters[i].Prob;
  372. TotalProb += Clusters[i].Prob;
  373. }
  374. BitTestInfo BTI;
  375. llvm::sort(CBV, [](const CaseBits &a, const CaseBits &b) {
  376. // Sort by probability first, number of bits second, bit mask third.
  377. if (a.ExtraProb != b.ExtraProb)
  378. return a.ExtraProb > b.ExtraProb;
  379. if (a.Bits != b.Bits)
  380. return a.Bits > b.Bits;
  381. return a.Mask < b.Mask;
  382. });
  383. for (auto &CB : CBV) {
  384. MachineBasicBlock *BitTestBB =
  385. FuncInfo.MF->CreateMachineBasicBlock(SI->getParent());
  386. BTI.push_back(BitTestCase(CB.Mask, BitTestBB, CB.BB, CB.ExtraProb));
  387. }
  388. BitTestCases.emplace_back(std::move(LowBound), std::move(CmpRange),
  389. SI->getCondition(), -1U, MVT::Other, false,
  390. ContiguousRange, nullptr, nullptr, std::move(BTI),
  391. TotalProb);
  392. BTCluster = CaseCluster::bitTests(Clusters[First].Low, Clusters[Last].High,
  393. BitTestCases.size() - 1, TotalProb);
  394. return true;
  395. }
  396. void SwitchCG::sortAndRangeify(CaseClusterVector &Clusters) {
  397. #ifndef NDEBUG
  398. for (const CaseCluster &CC : Clusters)
  399. assert(CC.Low == CC.High && "Input clusters must be single-case");
  400. #endif
  401. llvm::sort(Clusters, [](const CaseCluster &a, const CaseCluster &b) {
  402. return a.Low->getValue().slt(b.Low->getValue());
  403. });
  404. // Merge adjacent clusters with the same destination.
  405. const unsigned N = Clusters.size();
  406. unsigned DstIndex = 0;
  407. for (unsigned SrcIndex = 0; SrcIndex < N; ++SrcIndex) {
  408. CaseCluster &CC = Clusters[SrcIndex];
  409. const ConstantInt *CaseVal = CC.Low;
  410. MachineBasicBlock *Succ = CC.MBB;
  411. if (DstIndex != 0 && Clusters[DstIndex - 1].MBB == Succ &&
  412. (CaseVal->getValue() - Clusters[DstIndex - 1].High->getValue()) == 1) {
  413. // If this case has the same successor and is a neighbour, merge it into
  414. // the previous cluster.
  415. Clusters[DstIndex - 1].High = CaseVal;
  416. Clusters[DstIndex - 1].Prob += CC.Prob;
  417. } else {
  418. std::memmove(&Clusters[DstIndex++], &Clusters[SrcIndex],
  419. sizeof(Clusters[SrcIndex]));
  420. }
  421. }
  422. Clusters.resize(DstIndex);
  423. }