CodeExtractor.cpp 58 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566
  1. //===- CodeExtractor.cpp - Pull code region into a new function -----------===//
  2. //
  3. // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
  4. // See https://llvm.org/LICENSE.txt for license information.
  5. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
  6. //
  7. //===----------------------------------------------------------------------===//
  8. //
  9. // This file implements the interface to tear out a code region, such as an
  10. // individual loop or a parallel section, into a new function, replacing it with
  11. // a call to the new function.
  12. //
  13. //===----------------------------------------------------------------------===//
  14. #include "llvm/Transforms/Utils/CodeExtractor.h"
  15. #include "llvm/ADT/ArrayRef.h"
  16. #include "llvm/ADT/DenseMap.h"
  17. #include "llvm/ADT/Optional.h"
  18. #include "llvm/ADT/STLExtras.h"
  19. #include "llvm/ADT/SetVector.h"
  20. #include "llvm/ADT/SmallPtrSet.h"
  21. #include "llvm/ADT/SmallVector.h"
  22. #include "llvm/Analysis/AssumptionCache.h"
  23. #include "llvm/Analysis/BlockFrequencyInfo.h"
  24. #include "llvm/Analysis/BlockFrequencyInfoImpl.h"
  25. #include "llvm/Analysis/BranchProbabilityInfo.h"
  26. #include "llvm/Analysis/LoopInfo.h"
  27. #include "llvm/IR/Argument.h"
  28. #include "llvm/IR/Attributes.h"
  29. #include "llvm/IR/BasicBlock.h"
  30. #include "llvm/IR/CFG.h"
  31. #include "llvm/IR/Constant.h"
  32. #include "llvm/IR/Constants.h"
  33. #include "llvm/IR/DataLayout.h"
  34. #include "llvm/IR/DerivedTypes.h"
  35. #include "llvm/IR/Dominators.h"
  36. #include "llvm/IR/Function.h"
  37. #include "llvm/IR/GlobalValue.h"
  38. #include "llvm/IR/InstrTypes.h"
  39. #include "llvm/IR/Instruction.h"
  40. #include "llvm/IR/Instructions.h"
  41. #include "llvm/IR/IntrinsicInst.h"
  42. #include "llvm/IR/Intrinsics.h"
  43. #include "llvm/IR/LLVMContext.h"
  44. #include "llvm/IR/MDBuilder.h"
  45. #include "llvm/IR/Module.h"
  46. #include "llvm/IR/PatternMatch.h"
  47. #include "llvm/IR/Type.h"
  48. #include "llvm/IR/User.h"
  49. #include "llvm/IR/Value.h"
  50. #include "llvm/IR/Verifier.h"
  51. #include "llvm/Pass.h"
  52. #include "llvm/Support/BlockFrequency.h"
  53. #include "llvm/Support/BranchProbability.h"
  54. #include "llvm/Support/Casting.h"
  55. #include "llvm/Support/CommandLine.h"
  56. #include "llvm/Support/Debug.h"
  57. #include "llvm/Support/ErrorHandling.h"
  58. #include "llvm/Support/raw_ostream.h"
  59. #include "llvm/Transforms/Utils/BasicBlockUtils.h"
  60. #include "llvm/Transforms/Utils/Local.h"
  61. #include <cassert>
  62. #include <cstdint>
  63. #include <iterator>
  64. #include <map>
  65. #include <set>
  66. #include <utility>
  67. #include <vector>
  68. using namespace llvm;
  69. using namespace llvm::PatternMatch;
  70. using ProfileCount = Function::ProfileCount;
  71. #define DEBUG_TYPE "code-extractor"
  72. // Provide a command-line option to aggregate function arguments into a struct
  73. // for functions produced by the code extractor. This is useful when converting
  74. // extracted functions to pthread-based code, as only one argument (void*) can
  75. // be passed in to pthread_create().
// Hidden flag; when set it overrides (is OR'd with) the AggregateArgs
// constructor parameter of every CodeExtractor instance.
static cl::opt<bool>
    AggregateArgsOpt("aggregate-extracted-args", cl::Hidden,
                     cl::desc("Aggregate arguments to code-extracted functions"));
  79. /// Test whether a block is valid for extraction.
  80. static bool isBlockValidForExtraction(const BasicBlock &BB,
  81. const SetVector<BasicBlock *> &Result,
  82. bool AllowVarArgs, bool AllowAlloca) {
  83. // taking the address of a basic block moved to another function is illegal
  84. if (BB.hasAddressTaken())
  85. return false;
  86. // don't hoist code that uses another basicblock address, as it's likely to
  87. // lead to unexpected behavior, like cross-function jumps
  88. SmallPtrSet<User const *, 16> Visited;
  89. SmallVector<User const *, 16> ToVisit;
  90. for (Instruction const &Inst : BB)
  91. ToVisit.push_back(&Inst);
  92. while (!ToVisit.empty()) {
  93. User const *Curr = ToVisit.pop_back_val();
  94. if (!Visited.insert(Curr).second)
  95. continue;
  96. if (isa<BlockAddress const>(Curr))
  97. return false; // even a reference to self is likely to be not compatible
  98. if (isa<Instruction>(Curr) && cast<Instruction>(Curr)->getParent() != &BB)
  99. continue;
  100. for (auto const &U : Curr->operands()) {
  101. if (auto *UU = dyn_cast<User>(U))
  102. ToVisit.push_back(UU);
  103. }
  104. }
  105. // If explicitly requested, allow vastart and alloca. For invoke instructions
  106. // verify that extraction is valid.
  107. for (BasicBlock::const_iterator I = BB.begin(), E = BB.end(); I != E; ++I) {
  108. if (isa<AllocaInst>(I)) {
  109. if (!AllowAlloca)
  110. return false;
  111. continue;
  112. }
  113. if (const auto *II = dyn_cast<InvokeInst>(I)) {
  114. // Unwind destination (either a landingpad, catchswitch, or cleanuppad)
  115. // must be a part of the subgraph which is being extracted.
  116. if (auto *UBB = II->getUnwindDest())
  117. if (!Result.count(UBB))
  118. return false;
  119. continue;
  120. }
  121. // All catch handlers of a catchswitch instruction as well as the unwind
  122. // destination must be in the subgraph.
  123. if (const auto *CSI = dyn_cast<CatchSwitchInst>(I)) {
  124. if (auto *UBB = CSI->getUnwindDest())
  125. if (!Result.count(UBB))
  126. return false;
  127. for (auto *HBB : CSI->handlers())
  128. if (!Result.count(const_cast<BasicBlock*>(HBB)))
  129. return false;
  130. continue;
  131. }
  132. // Make sure that entire catch handler is within subgraph. It is sufficient
  133. // to check that catch return's block is in the list.
  134. if (const auto *CPI = dyn_cast<CatchPadInst>(I)) {
  135. for (const auto *U : CPI->users())
  136. if (const auto *CRI = dyn_cast<CatchReturnInst>(U))
  137. if (!Result.count(const_cast<BasicBlock*>(CRI->getParent())))
  138. return false;
  139. continue;
  140. }
  141. // And do similar checks for cleanup handler - the entire handler must be
  142. // in subgraph which is going to be extracted. For cleanup return should
  143. // additionally check that the unwind destination is also in the subgraph.
  144. if (const auto *CPI = dyn_cast<CleanupPadInst>(I)) {
  145. for (const auto *U : CPI->users())
  146. if (const auto *CRI = dyn_cast<CleanupReturnInst>(U))
  147. if (!Result.count(const_cast<BasicBlock*>(CRI->getParent())))
  148. return false;
  149. continue;
  150. }
  151. if (const auto *CRI = dyn_cast<CleanupReturnInst>(I)) {
  152. if (auto *UBB = CRI->getUnwindDest())
  153. if (!Result.count(UBB))
  154. return false;
  155. continue;
  156. }
  157. if (const CallInst *CI = dyn_cast<CallInst>(I)) {
  158. if (const Function *F = CI->getCalledFunction()) {
  159. auto IID = F->getIntrinsicID();
  160. if (IID == Intrinsic::vastart) {
  161. if (AllowVarArgs)
  162. continue;
  163. else
  164. return false;
  165. }
  166. // Currently, we miscompile outlined copies of eh_typid_for. There are
  167. // proposals for fixing this in llvm.org/PR39545.
  168. if (IID == Intrinsic::eh_typeid_for)
  169. return false;
  170. }
  171. }
  172. }
  173. return true;
  174. }
  175. /// Build a set of blocks to extract if the input blocks are viable.
  176. static SetVector<BasicBlock *>
  177. buildExtractionBlockSet(ArrayRef<BasicBlock *> BBs, DominatorTree *DT,
  178. bool AllowVarArgs, bool AllowAlloca) {
  179. assert(!BBs.empty() && "The set of blocks to extract must be non-empty");
  180. SetVector<BasicBlock *> Result;
  181. // Loop over the blocks, adding them to our set-vector, and aborting with an
  182. // empty set if we encounter invalid blocks.
  183. for (BasicBlock *BB : BBs) {
  184. // If this block is dead, don't process it.
  185. if (DT && !DT->isReachableFromEntry(BB))
  186. continue;
  187. if (!Result.insert(BB))
  188. llvm_unreachable("Repeated basic blocks in extraction input");
  189. }
  190. LLVM_DEBUG(dbgs() << "Region front block: " << Result.front()->getName()
  191. << '\n');
  192. for (auto *BB : Result) {
  193. if (!isBlockValidForExtraction(*BB, Result, AllowVarArgs, AllowAlloca))
  194. return {};
  195. // Make sure that the first block is not a landing pad.
  196. if (BB == Result.front()) {
  197. if (BB->isEHPad()) {
  198. LLVM_DEBUG(dbgs() << "The first block cannot be an unwind block\n");
  199. return {};
  200. }
  201. continue;
  202. }
  203. // All blocks other than the first must not have predecessors outside of
  204. // the subgraph which is being extracted.
  205. for (auto *PBB : predecessors(BB))
  206. if (!Result.count(PBB)) {
  207. LLVM_DEBUG(dbgs() << "No blocks in this region may have entries from "
  208. "outside the region except for the first block!\n"
  209. << "Problematic source BB: " << BB->getName() << "\n"
  210. << "Problematic destination BB: " << PBB->getName()
  211. << "\n");
  212. return {};
  213. }
  214. }
  215. return Result;
  216. }
  217. CodeExtractor::CodeExtractor(ArrayRef<BasicBlock *> BBs, DominatorTree *DT,
  218. bool AggregateArgs, BlockFrequencyInfo *BFI,
  219. BranchProbabilityInfo *BPI, AssumptionCache *AC,
  220. bool AllowVarArgs, bool AllowAlloca,
  221. std::string Suffix)
  222. : DT(DT), AggregateArgs(AggregateArgs || AggregateArgsOpt), BFI(BFI),
  223. BPI(BPI), AC(AC), AllowVarArgs(AllowVarArgs),
  224. Blocks(buildExtractionBlockSet(BBs, DT, AllowVarArgs, AllowAlloca)),
  225. Suffix(Suffix) {}
  226. CodeExtractor::CodeExtractor(DominatorTree &DT, Loop &L, bool AggregateArgs,
  227. BlockFrequencyInfo *BFI,
  228. BranchProbabilityInfo *BPI, AssumptionCache *AC,
  229. std::string Suffix)
  230. : DT(&DT), AggregateArgs(AggregateArgs || AggregateArgsOpt), BFI(BFI),
  231. BPI(BPI), AC(AC), AllowVarArgs(false),
  232. Blocks(buildExtractionBlockSet(L.getBlocks(), &DT,
  233. /* AllowVarArgs */ false,
  234. /* AllowAlloca */ false)),
  235. Suffix(Suffix) {}
  236. /// definedInRegion - Return true if the specified value is defined in the
  237. /// extracted region.
  238. static bool definedInRegion(const SetVector<BasicBlock *> &Blocks, Value *V) {
  239. if (Instruction *I = dyn_cast<Instruction>(V))
  240. if (Blocks.count(I->getParent()))
  241. return true;
  242. return false;
  243. }
  244. /// definedInCaller - Return true if the specified value is defined in the
  245. /// function being code extracted, but not in the region being extracted.
  246. /// These values must be passed in as live-ins to the function.
  247. static bool definedInCaller(const SetVector<BasicBlock *> &Blocks, Value *V) {
  248. if (isa<Argument>(V)) return true;
  249. if (Instruction *I = dyn_cast<Instruction>(V))
  250. if (!Blocks.count(I->getParent()))
  251. return true;
  252. return false;
  253. }
  254. static BasicBlock *getCommonExitBlock(const SetVector<BasicBlock *> &Blocks) {
  255. BasicBlock *CommonExitBlock = nullptr;
  256. auto hasNonCommonExitSucc = [&](BasicBlock *Block) {
  257. for (auto *Succ : successors(Block)) {
  258. // Internal edges, ok.
  259. if (Blocks.count(Succ))
  260. continue;
  261. if (!CommonExitBlock) {
  262. CommonExitBlock = Succ;
  263. continue;
  264. }
  265. if (CommonExitBlock == Succ)
  266. continue;
  267. return true;
  268. }
  269. return false;
  270. };
  271. if (any_of(Blocks, hasNonCommonExitSucc))
  272. return nullptr;
  273. return CommonExitBlock;
  274. }
  275. bool CodeExtractor::isLegalToShrinkwrapLifetimeMarkers(
  276. Instruction *Addr) const {
  277. AllocaInst *AI = cast<AllocaInst>(Addr->stripInBoundsConstantOffsets());
  278. Function *Func = (*Blocks.begin())->getParent();
  279. for (BasicBlock &BB : *Func) {
  280. if (Blocks.count(&BB))
  281. continue;
  282. for (Instruction &II : BB) {
  283. if (isa<DbgInfoIntrinsic>(II))
  284. continue;
  285. unsigned Opcode = II.getOpcode();
  286. Value *MemAddr = nullptr;
  287. switch (Opcode) {
  288. case Instruction::Store:
  289. case Instruction::Load: {
  290. if (Opcode == Instruction::Store) {
  291. StoreInst *SI = cast<StoreInst>(&II);
  292. MemAddr = SI->getPointerOperand();
  293. } else {
  294. LoadInst *LI = cast<LoadInst>(&II);
  295. MemAddr = LI->getPointerOperand();
  296. }
  297. // Global variable can not be aliased with locals.
  298. if (dyn_cast<Constant>(MemAddr))
  299. break;
  300. Value *Base = MemAddr->stripInBoundsConstantOffsets();
  301. if (!isa<AllocaInst>(Base) || Base == AI)
  302. return false;
  303. break;
  304. }
  305. default: {
  306. IntrinsicInst *IntrInst = dyn_cast<IntrinsicInst>(&II);
  307. if (IntrInst) {
  308. if (IntrInst->isLifetimeStartOrEnd())
  309. break;
  310. return false;
  311. }
  312. // Treat all the other cases conservatively if it has side effects.
  313. if (II.mayHaveSideEffects())
  314. return false;
  315. }
  316. }
  317. }
  318. }
  319. return true;
  320. }
/// Find (or create by splitting \p CommonExitBlock) a block inside the
/// region into which lifetime.end markers can be hoisted.
///
/// If exactly one region block branches to the exit, that block is returned.
/// Otherwise the exit block is split; the upper half is absorbed into the
/// region (outside predecessors are redirected to the lower half) and
/// returned.
BasicBlock *
CodeExtractor::findOrCreateBlockForHoisting(BasicBlock *CommonExitBlock) {
  BasicBlock *SinglePredFromOutlineRegion = nullptr;
  assert(!Blocks.count(CommonExitBlock) &&
         "Expect a block outside the region!");
  // Look for the unique in-region predecessor of the exit block; reset to
  // null and stop as soon as a second distinct one is seen.
  for (auto *Pred : predecessors(CommonExitBlock)) {
    if (!Blocks.count(Pred))
      continue;
    if (!SinglePredFromOutlineRegion) {
      SinglePredFromOutlineRegion = Pred;
    } else if (SinglePredFromOutlineRegion != Pred) {
      SinglePredFromOutlineRegion = nullptr;
      break;
    }
  }

  if (SinglePredFromOutlineRegion)
    return SinglePredFromOutlineRegion;

#ifndef NDEBUG
  // Returns the first instruction of BB if it is a PHI, null otherwise.
  // Note the loop body breaks on its first iteration either way, so at most
  // one instruction is ever inspected.
  auto getFirstPHI = [](BasicBlock *BB) {
    BasicBlock::iterator I = BB->begin();
    PHINode *FirstPhi = nullptr;
    while (I != BB->end()) {
      PHINode *Phi = dyn_cast<PHINode>(I);
      if (!Phi)
        break;
      if (!FirstPhi) {
        FirstPhi = Phi;
        break;
      }
    }
    return FirstPhi;
  };
  // If there are any phi nodes, the single pred either exists or has already
  // be created before code extraction.
  assert(!getFirstPHI(CommonExitBlock) && "Phi not expected");
#endif

  BasicBlock *NewExitBlock = CommonExitBlock->splitBasicBlock(
      CommonExitBlock->getFirstNonPHI()->getIterator());

  // Redirect every predecessor from outside the region to the new lower
  // half; in-region predecessors keep branching to CommonExitBlock.
  for (auto PI = pred_begin(CommonExitBlock), PE = pred_end(CommonExitBlock);
       PI != PE;) {
    BasicBlock *Pred = *PI++; // Advance before the edge is rewritten.
    if (Blocks.count(Pred))
      continue;
    Pred->getTerminator()->replaceUsesOfWith(CommonExitBlock, NewExitBlock);
  }
  // Now add the old exit block to the outline region.
  Blocks.insert(CommonExitBlock);
  return CommonExitBlock;
}
  370. // Find the pair of life time markers for address 'Addr' that are either
  371. // defined inside the outline region or can legally be shrinkwrapped into the
  372. // outline region. If there are not other untracked uses of the address, return
  373. // the pair of markers if found; otherwise return a pair of nullptr.
  374. CodeExtractor::LifetimeMarkerInfo
  375. CodeExtractor::getLifetimeMarkers(Instruction *Addr,
  376. BasicBlock *ExitBlock) const {
  377. LifetimeMarkerInfo Info;
  378. for (User *U : Addr->users()) {
  379. IntrinsicInst *IntrInst = dyn_cast<IntrinsicInst>(U);
  380. if (IntrInst) {
  381. if (IntrInst->getIntrinsicID() == Intrinsic::lifetime_start) {
  382. // Do not handle the case where Addr has multiple start markers.
  383. if (Info.LifeStart)
  384. return {};
  385. Info.LifeStart = IntrInst;
  386. }
  387. if (IntrInst->getIntrinsicID() == Intrinsic::lifetime_end) {
  388. if (Info.LifeEnd)
  389. return {};
  390. Info.LifeEnd = IntrInst;
  391. }
  392. continue;
  393. }
  394. // Find untracked uses of the address, bail.
  395. if (!definedInRegion(Blocks, U))
  396. return {};
  397. }
  398. if (!Info.LifeStart || !Info.LifeEnd)
  399. return {};
  400. Info.SinkLifeStart = !definedInRegion(Blocks, Info.LifeStart);
  401. Info.HoistLifeEnd = !definedInRegion(Blocks, Info.LifeEnd);
  402. // Do legality check.
  403. if ((Info.SinkLifeStart || Info.HoistLifeEnd) &&
  404. !isLegalToShrinkwrapLifetimeMarkers(Addr))
  405. return {};
  406. // Check to see if we have a place to do hoisting, if not, bail.
  407. if (Info.HoistLifeEnd && !ExitBlock)
  408. return {};
  409. return Info;
  410. }
/// Collect allocas declared outside the region whose lifetime markers show
/// they are only live inside it, so they can be sunk into the extracted
/// function.
///
/// \param SinkCands  [out] instructions to sink into the new function.
/// \param HoistCands [out] lifetime.end markers to hoist to the exit block.
/// \param ExitBlock  [out] the common exit block of the region, if any.
void CodeExtractor::findAllocas(ValueSet &SinkCands, ValueSet &HoistCands,
                                BasicBlock *&ExitBlock) const {
  Function *Func = (*Blocks.begin())->getParent();
  ExitBlock = getCommonExitBlock(Blocks);

  // Record the sink/hoist decisions carried by a LifetimeMarkerInfo; returns
  // false when there were no usable markers (LifeStart unset).
  auto moveOrIgnoreLifetimeMarkers =
      [&](const LifetimeMarkerInfo &LMI) -> bool {
    if (!LMI.LifeStart)
      return false;
    if (LMI.SinkLifeStart) {
      LLVM_DEBUG(dbgs() << "Sinking lifetime.start: " << *LMI.LifeStart
                        << "\n");
      SinkCands.insert(LMI.LifeStart);
    }
    if (LMI.HoistLifeEnd) {
      LLVM_DEBUG(dbgs() << "Hoisting lifetime.end: " << *LMI.LifeEnd << "\n");
      HoistCands.insert(LMI.LifeEnd);
    }
    return true;
  };

  // Only blocks OUTSIDE the region are scanned: an alloca already inside the
  // region moves with it and needs no special handling.
  for (BasicBlock &BB : *Func) {
    if (Blocks.count(&BB))
      continue;
    for (Instruction &II : BB) {
      auto *AI = dyn_cast<AllocaInst>(&II);
      if (!AI)
        continue;

      LifetimeMarkerInfo MarkerInfo = getLifetimeMarkers(AI, ExitBlock);
      bool Moved = moveOrIgnoreLifetimeMarkers(MarkerInfo);
      if (Moved) {
        LLVM_DEBUG(dbgs() << "Sinking alloca: " << *AI << "\n");
        SinkCands.insert(AI);
        continue;
      }

      // Follow any bitcasts.
      // The alloca may only be addressed through zero-offset casts; collect
      // those casts that carry their own lifetime markers.
      SmallVector<Instruction *, 2> Bitcasts;
      SmallVector<LifetimeMarkerInfo, 2> BitcastLifetimeInfo;
      for (User *U : AI->users()) {
        if (U->stripInBoundsConstantOffsets() == AI) {
          Instruction *Bitcast = cast<Instruction>(U);
          LifetimeMarkerInfo LMI = getLifetimeMarkers(Bitcast, ExitBlock);
          if (LMI.LifeStart) {
            Bitcasts.push_back(Bitcast);
            BitcastLifetimeInfo.push_back(LMI);
            continue;
          }
        }

        // Found unknown use of AI.
        if (!definedInRegion(Blocks, U)) {
          Bitcasts.clear();
          break;
        }
      }

      // Either no bitcasts reference the alloca or there are unknown uses.
      if (Bitcasts.empty())
        continue;

      LLVM_DEBUG(dbgs() << "Sinking alloca (via bitcast): " << *AI << "\n");
      SinkCands.insert(AI);
      for (unsigned I = 0, E = Bitcasts.size(); I != E; ++I) {
        Instruction *BitcastAddr = Bitcasts[I];
        const LifetimeMarkerInfo &LMI = BitcastLifetimeInfo[I];
        assert(LMI.LifeStart &&
               "Unsafe to sink bitcast without lifetime markers");
        moveOrIgnoreLifetimeMarkers(LMI);
        // A cast defined outside the region must be sunk along with the
        // alloca so its uses inside the region remain valid.
        if (!definedInRegion(Blocks, BitcastAddr)) {
          LLVM_DEBUG(dbgs() << "Sinking bitcast-of-alloca: " << *BitcastAddr
                            << "\n");
          SinkCands.insert(BitcastAddr);
        }
      }
    }
  }
}
  483. void CodeExtractor::findInputsOutputs(ValueSet &Inputs, ValueSet &Outputs,
  484. const ValueSet &SinkCands) const {
  485. for (BasicBlock *BB : Blocks) {
  486. // If a used value is defined outside the region, it's an input. If an
  487. // instruction is used outside the region, it's an output.
  488. for (Instruction &II : *BB) {
  489. for (User::op_iterator OI = II.op_begin(), OE = II.op_end(); OI != OE;
  490. ++OI) {
  491. Value *V = *OI;
  492. if (!SinkCands.count(V) && definedInCaller(Blocks, V))
  493. Inputs.insert(V);
  494. }
  495. for (User *U : II.users())
  496. if (!definedInRegion(Blocks, U)) {
  497. Outputs.insert(&II);
  498. break;
  499. }
  500. }
  501. }
  502. }
/// severSplitPHINodesOfEntry - If a PHI node has multiple inputs from outside
/// of the region, we need to split the entry block of the region so that the
/// PHI node is easier to deal with.
///
/// On return \p Header points to the new second block, which becomes the
/// region entry; the original block stays outside the region and keeps the
/// PHIs that merge values from outside.
void CodeExtractor::severSplitPHINodesOfEntry(BasicBlock *&Header) {
  unsigned NumPredsFromRegion = 0;
  unsigned NumPredsOutsideRegion = 0;

  // When Header is the function's entry block the early-exit checks below
  // are skipped and the block is always split.
  if (Header != &Header->getParent()->getEntryBlock()) {
    PHINode *PN = dyn_cast<PHINode>(Header->begin());
    if (!PN) return;  // No PHI nodes.

    // If the header node contains any PHI nodes, check to see if there is more
    // than one entry from outside the region.  If so, we need to sever the
    // header block into two.
    for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i)
      if (Blocks.count(PN->getIncomingBlock(i)))
        ++NumPredsFromRegion;
      else
        ++NumPredsOutsideRegion;

    // If there is one (or fewer) predecessor from outside the region, we don't
    // need to do anything special.
    if (NumPredsOutsideRegion <= 1) return;
  }

  // Otherwise, we need to split the header block into two pieces: one
  // containing PHI nodes merging values from outside of the region, and a
  // second that contains all of the code for the block and merges back any
  // incoming values from inside of the region.
  BasicBlock *NewBB = SplitBlock(Header, Header->getFirstNonPHI(), DT);

  // We only want to code extract the second block now, and it becomes the new
  // header of the region.
  BasicBlock *OldPred = Header;
  Blocks.remove(OldPred);
  Blocks.insert(NewBB);
  Header = NewBB;

  // Okay, now we need to adjust the PHI nodes and any branches from within the
  // region to go to the new header block instead of the old header block.
  if (NumPredsFromRegion) {
    PHINode *PN = cast<PHINode>(OldPred->begin());
    // Loop over all of the predecessors of OldPred that are in the region,
    // changing them to branch to NewBB instead.
    for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i)
      if (Blocks.count(PN->getIncomingBlock(i))) {
        Instruction *TI = PN->getIncomingBlock(i)->getTerminator();
        TI->replaceUsesOfWith(OldPred, NewBB);
      }

    // Okay, everything within the region is now branching to the right block, we
    // just have to update the PHI nodes now, inserting PHI nodes into NewBB.
    BasicBlock::iterator AfterPHIs;
    for (AfterPHIs = OldPred->begin(); isa<PHINode>(AfterPHIs); ++AfterPHIs) {
      PHINode *PN = cast<PHINode>(AfterPHIs);
      // Create a new PHI node in the new region, which has an incoming value
      // from OldPred of PN.
      PHINode *NewPN = PHINode::Create(PN->getType(), 1 + NumPredsFromRegion,
                                       PN->getName() + ".ce", &NewBB->front());
      PN->replaceAllUsesWith(NewPN);
      NewPN->addIncoming(PN, OldPred);

      // Loop over all of the incoming value in PN, moving them to NewPN if they
      // are from the extracted region.
      for (unsigned i = 0; i != PN->getNumIncomingValues(); ++i) {
        if (Blocks.count(PN->getIncomingBlock(i))) {
          NewPN->addIncoming(PN->getIncomingValue(i), PN->getIncomingBlock(i));
          PN->removeIncomingValue(i);
          // removeIncomingValue shifts later entries down; re-check index i.
          --i;
        }
      }
    }
  }
}
/// severSplitPHINodesOfExits - if PHI nodes in exit blocks have inputs from
/// outlined region, we split these PHIs on two: one with inputs from region
/// and other with remaining incoming blocks; then first PHIs are placed in
/// outlined region.
void CodeExtractor::severSplitPHINodesOfExits(
    const SmallPtrSetImpl<BasicBlock *> &Exits) {
  for (BasicBlock *ExitBB : Exits) {
    // Lazily-created block that will receive the region-side PHI halves.
    BasicBlock *NewBB = nullptr;

    for (PHINode &PN : ExitBB->phis()) {
      // Find all incoming values from the outlining region.
      SmallVector<unsigned, 2> IncomingVals;
      for (unsigned i = 0; i < PN.getNumIncomingValues(); ++i)
        if (Blocks.count(PN.getIncomingBlock(i)))
          IncomingVals.push_back(i);

      // Do not process PHI if there is one (or fewer) predecessor from region.
      // If PHI has exactly one predecessor from region, only this one incoming
      // will be replaced on codeRepl block, so it should be safe to skip PHI.
      if (IncomingVals.size() <= 1)
        continue;

      // Create block for new PHIs and add it to the list of outlined if it
      // wasn't done before.
      if (!NewBB) {
        NewBB = BasicBlock::Create(ExitBB->getContext(),
                                   ExitBB->getName() + ".split",
                                   ExitBB->getParent(), ExitBB);
        // Reroute the region-side predecessors of ExitBB through NewBB.
        SmallVector<BasicBlock *, 4> Preds(pred_begin(ExitBB),
                                           pred_end(ExitBB));
        for (BasicBlock *PredBB : Preds)
          if (Blocks.count(PredBB))
            PredBB->getTerminator()->replaceUsesOfWith(ExitBB, NewBB);
        BranchInst::Create(ExitBB, NewBB);
        Blocks.insert(NewBB);
      }

      // Split this PHI: region incomings move into a new PHI in NewBB, and
      // the original PHI gets a single edge from NewBB instead.
      PHINode *NewPN =
          PHINode::Create(PN.getType(), IncomingVals.size(),
                          PN.getName() + ".ce", NewBB->getFirstNonPHI());
      for (unsigned i : IncomingVals)
        NewPN->addIncoming(PN.getIncomingValue(i), PN.getIncomingBlock(i));
      // Remove in reverse so the recorded indices stay valid as we go.
      for (unsigned i : reverse(IncomingVals))
        PN.removeIncomingValue(i, false);
      PN.addIncoming(NewPN, NewBB);
    }
  }
}
  614. void CodeExtractor::splitReturnBlocks() {
  615. for (BasicBlock *Block : Blocks)
  616. if (ReturnInst *RI = dyn_cast<ReturnInst>(Block->getTerminator())) {
  617. BasicBlock *New =
  618. Block->splitBasicBlock(RI->getIterator(), Block->getName() + ".ret");
  619. if (DT) {
  620. // Old dominates New. New node dominates all other nodes dominated
  621. // by Old.
  622. DomTreeNode *OldNode = DT->getNode(Block);
  623. SmallVector<DomTreeNode *, 8> Children(OldNode->begin(),
  624. OldNode->end());
  625. DomTreeNode *NewNode = DT->addNewBlock(New, Block);
  626. for (DomTreeNode *I : Children)
  627. DT->changeImmediateDominator(I, NewNode);
  628. }
  629. }
  630. }
/// constructFunction - make a function based on inputs and outputs, as follows:
/// f(in0, ..., inN, out0, ..., outN)
///
/// \param inputs      values live-in to the region; they become the leading
///                    parameters (or leading struct fields if AggregateArgs).
/// \param outputs     values live-out of the region; passed by pointer (or as
///                    trailing struct fields if AggregateArgs).
/// \param header      the single-entry block of the region (used for context
///                    and for naming the new function).
/// \param newRootNode the artificial entry block for the new function.
/// \param newHeader   the block in the old function that replaces the region;
///                    external branches to \p header are redirected here.
/// \param oldFunction the function the region is being extracted from.
/// \param M           the module the new function is created in.
/// \return the newly created (still empty apart from \p newRootNode) function.
Function *CodeExtractor::constructFunction(const ValueSet &inputs,
                                           const ValueSet &outputs,
                                           BasicBlock *header,
                                           BasicBlock *newRootNode,
                                           BasicBlock *newHeader,
                                           Function *oldFunction,
                                           Module *M) {
  LLVM_DEBUG(dbgs() << "inputs: " << inputs.size() << "\n");
  LLVM_DEBUG(dbgs() << "outputs: " << outputs.size() << "\n");

  // This function returns unsigned, outputs will go back by reference.
  // 0/1 exits need no selector; 2 exits fit in an i1; more use an i16.
  switch (NumExitBlocks) {
  case 0:
  case 1: RetTy = Type::getVoidTy(header->getContext()); break;
  case 2: RetTy = Type::getInt1Ty(header->getContext()); break;
  default: RetTy = Type::getInt16Ty(header->getContext()); break;
  }

  std::vector<Type *> paramTy;

  // Add the types of the input values to the function's argument list
  for (Value *value : inputs) {
    LLVM_DEBUG(dbgs() << "value used in func: " << *value << "\n");
    paramTy.push_back(value->getType());
  }

  // Add the types of the output values to the function's argument list.
  // Without aggregation each output is returned through a pointer parameter.
  for (Value *output : outputs) {
    LLVM_DEBUG(dbgs() << "instr used in func: " << *output << "\n");
    if (AggregateArgs)
      paramTy.push_back(output->getType());
    else
      paramTy.push_back(PointerType::getUnqual(output->getType()));
  }

  LLVM_DEBUG({
    dbgs() << "Function type: " << *RetTy << " f(";
    for (Type *i : paramTy)
      dbgs() << *i << ", ";
    dbgs() << ")\n";
  });

  // Only initialized (and later read) when AggregateArgs is set and there is
  // at least one parameter; in that mode all params collapse into one struct
  // pointer argument.
  StructType *StructTy;
  if (AggregateArgs && (inputs.size() + outputs.size() > 0)) {
    StructTy = StructType::get(M->getContext(), paramTy);
    paramTy.clear();
    paramTy.push_back(PointerType::getUnqual(StructTy));
  }
  FunctionType *funcType =
      FunctionType::get(RetTy, paramTy,
                        AllowVarArgs && oldFunction->isVarArg());

  std::string SuffixToUse =
      Suffix.empty()
          ? (header->getName().empty() ? "extracted" : header->getName().str())
          : Suffix;
  // Create the new function
  Function *newFunction = Function::Create(
      funcType, GlobalValue::InternalLinkage, oldFunction->getAddressSpace(),
      oldFunction->getName() + "." + SuffixToUse, M);
  // If the old function is no-throw, so is the new one.
  if (oldFunction->doesNotThrow())
    newFunction->setDoesNotThrow();

  // Inherit the uwtable attribute if we need to.
  if (oldFunction->hasUWTable())
    newFunction->setHasUWTable();

  // Inherit all of the target dependent attributes and white-listed
  // target independent attributes.
  //  (e.g. If the extracted region contains a call to an x86.sse
  //  instruction we need to make sure that the extracted region has the
  //  "target-features" attribute allowing it to be lowered.
  // FIXME: This should be changed to check to see if a specific
  //  attribute can not be inherited.
  for (const auto &Attr : oldFunction->getAttributes().getFnAttributes()) {
    if (Attr.isStringAttribute()) {
      if (Attr.getKindAsString() == "thunk")
        continue;
    } else
      switch (Attr.getKindAsEnum()) {
      // Those attributes cannot be propagated safely. Explicitly list them
      // here so we get a warning if new attributes are added. This list also
      // includes non-function attributes.
      case Attribute::Alignment:
      case Attribute::AllocSize:
      case Attribute::ArgMemOnly:
      case Attribute::Builtin:
      case Attribute::ByVal:
      case Attribute::Convergent:
      case Attribute::Dereferenceable:
      case Attribute::DereferenceableOrNull:
      case Attribute::InAlloca:
      case Attribute::InReg:
      case Attribute::InaccessibleMemOnly:
      case Attribute::InaccessibleMemOrArgMemOnly:
      case Attribute::JumpTable:
      case Attribute::Naked:
      case Attribute::Nest:
      case Attribute::NoAlias:
      case Attribute::NoBuiltin:
      case Attribute::NoCapture:
      case Attribute::NoReturn:
      case Attribute::NoSync:
      case Attribute::None:
      case Attribute::NonNull:
      case Attribute::ReadNone:
      case Attribute::ReadOnly:
      case Attribute::Returned:
      case Attribute::ReturnsTwice:
      case Attribute::SExt:
      case Attribute::Speculatable:
      case Attribute::StackAlignment:
      case Attribute::StructRet:
      case Attribute::SwiftError:
      case Attribute::SwiftSelf:
      case Attribute::WillReturn:
      case Attribute::WriteOnly:
      case Attribute::ZExt:
      case Attribute::ImmArg:
      case Attribute::EndAttrKinds:
        continue;
      // Those attributes should be safe to propagate to the extracted function.
      case Attribute::AlwaysInline:
      case Attribute::Cold:
      case Attribute::NoRecurse:
      case Attribute::InlineHint:
      case Attribute::MinSize:
      case Attribute::NoDuplicate:
      case Attribute::NoFree:
      case Attribute::NoImplicitFloat:
      case Attribute::NoInline:
      case Attribute::NonLazyBind:
      case Attribute::NoRedZone:
      case Attribute::NoUnwind:
      case Attribute::OptForFuzzing:
      case Attribute::OptimizeNone:
      case Attribute::OptimizeForSize:
      case Attribute::SafeStack:
      case Attribute::ShadowCallStack:
      case Attribute::SanitizeAddress:
      case Attribute::SanitizeMemory:
      case Attribute::SanitizeThread:
      case Attribute::SanitizeHWAddress:
      case Attribute::SpeculativeLoadHardening:
      case Attribute::StackProtect:
      case Attribute::StackProtectReq:
      case Attribute::StackProtectStrong:
      case Attribute::StrictFP:
      case Attribute::UWTable:
      case Attribute::NoCfCheck:
        break;
      }

    newFunction->addFnAttr(Attr);
  }
  newFunction->getBasicBlockList().push_back(newRootNode);

  // Create an iterator to name all of the arguments we inserted.
  Function::arg_iterator AI = newFunction->arg_begin();

  // Rewrite all users of the inputs in the extracted region to use the
  // arguments (or appropriate addressing into struct) instead.
  for (unsigned i = 0, e = inputs.size(); i != e; ++i) {
    Value *RewriteVal;
    if (AggregateArgs) {
      // In aggregate mode, load the i-th input out of the struct argument at
      // the top of the new function's entry block (before its terminator).
      Value *Idx[2];
      Idx[0] = Constant::getNullValue(Type::getInt32Ty(header->getContext()));
      Idx[1] = ConstantInt::get(Type::getInt32Ty(header->getContext()), i);
      Instruction *TI = newFunction->begin()->getTerminator();
      GetElementPtrInst *GEP = GetElementPtrInst::Create(
          StructTy, &*AI, Idx, "gep_" + inputs[i]->getName(), TI);
      RewriteVal = new LoadInst(StructTy->getElementType(i), GEP,
                                "loadgep_" + inputs[i]->getName(), TI);
    } else
      RewriteVal = &*AI++;

    // Only uses inside the region are rewritten; users outside keep seeing
    // the original value in the old function.
    std::vector<User *> Users(inputs[i]->user_begin(), inputs[i]->user_end());
    for (User *use : Users)
      if (Instruction *inst = dyn_cast<Instruction>(use))
        if (Blocks.count(inst->getParent()))
          inst->replaceUsesOfWith(inputs[i], RewriteVal);
  }

  // Set names for input and output arguments.
  if (!AggregateArgs) {
    AI = newFunction->arg_begin();
    for (unsigned i = 0, e = inputs.size(); i != e; ++i, ++AI)
      AI->setName(inputs[i]->getName());
    for (unsigned i = 0, e = outputs.size(); i != e; ++i, ++AI)
      AI->setName(outputs[i]->getName()+".out");
  }

  // Rewrite branches to basic blocks outside of the loop to new dummy blocks
  // within the new function. This must be done before we lose track of which
  // blocks were originally in the code region.
  std::vector<User *> Users(header->user_begin(), header->user_end());
  for (unsigned i = 0, e = Users.size(); i != e; ++i)
    // The BasicBlock which contains the branch is not in the region
    // modify the branch target to a new block
    if (Instruction *I = dyn_cast<Instruction>(Users[i]))
      if (I->isTerminator() && !Blocks.count(I->getParent()) &&
          I->getParent()->getParent() == oldFunction)
        I->replaceUsesOfWith(header, newHeader);

  return newFunction;
}
  824. /// Erase lifetime.start markers which reference inputs to the extraction
  825. /// region, and insert the referenced memory into \p LifetimesStart.
  826. ///
  827. /// The extraction region is defined by a set of blocks (\p Blocks), and a set
  828. /// of allocas which will be moved from the caller function into the extracted
  829. /// function (\p SunkAllocas).
  830. static void eraseLifetimeMarkersOnInputs(const SetVector<BasicBlock *> &Blocks,
  831. const SetVector<Value *> &SunkAllocas,
  832. SetVector<Value *> &LifetimesStart) {
  833. for (BasicBlock *BB : Blocks) {
  834. for (auto It = BB->begin(), End = BB->end(); It != End;) {
  835. auto *II = dyn_cast<IntrinsicInst>(&*It);
  836. ++It;
  837. if (!II || !II->isLifetimeStartOrEnd())
  838. continue;
  839. // Get the memory operand of the lifetime marker. If the underlying
  840. // object is a sunk alloca, or is otherwise defined in the extraction
  841. // region, the lifetime marker must not be erased.
  842. Value *Mem = II->getOperand(1)->stripInBoundsOffsets();
  843. if (SunkAllocas.count(Mem) || definedInRegion(Blocks, Mem))
  844. continue;
  845. if (II->getIntrinsicID() == Intrinsic::lifetime_start)
  846. LifetimesStart.insert(Mem);
  847. II->eraseFromParent();
  848. }
  849. }
  850. }
  851. /// Insert lifetime start/end markers surrounding the call to the new function
  852. /// for objects defined in the caller.
  853. static void insertLifetimeMarkersSurroundingCall(
  854. Module *M, ArrayRef<Value *> LifetimesStart, ArrayRef<Value *> LifetimesEnd,
  855. CallInst *TheCall) {
  856. LLVMContext &Ctx = M->getContext();
  857. auto Int8PtrTy = Type::getInt8PtrTy(Ctx);
  858. auto NegativeOne = ConstantInt::getSigned(Type::getInt64Ty(Ctx), -1);
  859. Instruction *Term = TheCall->getParent()->getTerminator();
  860. // The memory argument to a lifetime marker must be a i8*. Cache any bitcasts
  861. // needed to satisfy this requirement so they may be reused.
  862. DenseMap<Value *, Value *> Bitcasts;
  863. // Emit lifetime markers for the pointers given in \p Objects. Insert the
  864. // markers before the call if \p InsertBefore, and after the call otherwise.
  865. auto insertMarkers = [&](Function *MarkerFunc, ArrayRef<Value *> Objects,
  866. bool InsertBefore) {
  867. for (Value *Mem : Objects) {
  868. assert((!isa<Instruction>(Mem) || cast<Instruction>(Mem)->getFunction() ==
  869. TheCall->getFunction()) &&
  870. "Input memory not defined in original function");
  871. Value *&MemAsI8Ptr = Bitcasts[Mem];
  872. if (!MemAsI8Ptr) {
  873. if (Mem->getType() == Int8PtrTy)
  874. MemAsI8Ptr = Mem;
  875. else
  876. MemAsI8Ptr =
  877. CastInst::CreatePointerCast(Mem, Int8PtrTy, "lt.cast", TheCall);
  878. }
  879. auto Marker = CallInst::Create(MarkerFunc, {NegativeOne, MemAsI8Ptr});
  880. if (InsertBefore)
  881. Marker->insertBefore(TheCall);
  882. else
  883. Marker->insertBefore(Term);
  884. }
  885. };
  886. if (!LifetimesStart.empty()) {
  887. auto StartFn = llvm::Intrinsic::getDeclaration(
  888. M, llvm::Intrinsic::lifetime_start, Int8PtrTy);
  889. insertMarkers(StartFn, LifetimesStart, /*InsertBefore=*/true);
  890. }
  891. if (!LifetimesEnd.empty()) {
  892. auto EndFn = llvm::Intrinsic::getDeclaration(
  893. M, llvm::Intrinsic::lifetime_end, Int8PtrTy);
  894. insertMarkers(EndFn, LifetimesEnd, /*InsertBefore=*/false);
  895. }
  896. }
/// emitCallAndSwitchStatement - This method sets up the caller side by adding
/// the call instruction, splitting any PHI nodes in the header block as
/// necessary.
///
/// \param newFunction  the extracted function to call.
/// \param codeReplacer the block in the old function that replaces the region;
///                     the call, reloads, and exit dispatch are emitted here.
/// \param inputs       live-in values, passed as arguments (or struct fields).
/// \param outputs      live-out values, returned through pointers (or struct
///                     fields) and reloaded after the call.
/// \return the emitted call instruction.
CallInst *CodeExtractor::emitCallAndSwitchStatement(Function *newFunction,
                                                    BasicBlock *codeReplacer,
                                                    ValueSet &inputs,
                                                    ValueSet &outputs) {
  // Emit a call to the new function, passing in: *pointer to struct (if
  // aggregating parameters), or plain inputs and allocated memory for outputs
  std::vector<Value *> params, StructValues, ReloadOutputs, Reloads;

  Module *M = newFunction->getParent();
  LLVMContext &Context = M->getContext();
  const DataLayout &DL = M->getDataLayout();
  CallInst *call = nullptr;

  // Add inputs as params, or to be filled into the struct
  unsigned ArgNo = 0;
  SmallVector<unsigned, 1> SwiftErrorArgs;
  for (Value *input : inputs) {
    if (AggregateArgs)
      StructValues.push_back(input);
    else {
      params.push_back(input);
      // swifterror values must keep the attribute on both the call and the
      // callee; remember their argument positions for later.
      if (input->isSwiftError())
        SwiftErrorArgs.push_back(ArgNo);
    }
    ++ArgNo;
  }

  // Create allocas for the outputs; they live in the old function's entry
  // block so they dominate the call.
  for (Value *output : outputs) {
    if (AggregateArgs) {
      StructValues.push_back(output);
    } else {
      AllocaInst *alloca =
        new AllocaInst(output->getType(), DL.getAllocaAddrSpace(),
                       nullptr, output->getName() + ".loc",
                       &codeReplacer->getParent()->front().front());
      ReloadOutputs.push_back(alloca);
      params.push_back(alloca);
    }
  }

  StructType *StructArgTy = nullptr;
  AllocaInst *Struct = nullptr;
  if (AggregateArgs && (inputs.size() + outputs.size() > 0)) {
    std::vector<Type *> ArgTypes;
    for (ValueSet::iterator v = StructValues.begin(),
           ve = StructValues.end(); v != ve; ++v)
      ArgTypes.push_back((*v)->getType());

    // Allocate a struct at the beginning of this function
    StructArgTy = StructType::get(newFunction->getContext(), ArgTypes);
    Struct = new AllocaInst(StructArgTy, DL.getAllocaAddrSpace(), nullptr,
                            "structArg",
                            &codeReplacer->getParent()->front().front());
    params.push_back(Struct);

    // Store each input into its struct field before the call.
    for (unsigned i = 0, e = inputs.size(); i != e; ++i) {
      Value *Idx[2];
      Idx[0] = Constant::getNullValue(Type::getInt32Ty(Context));
      Idx[1] = ConstantInt::get(Type::getInt32Ty(Context), i);
      GetElementPtrInst *GEP = GetElementPtrInst::Create(
          StructArgTy, Struct, Idx, "gep_" + StructValues[i]->getName());
      codeReplacer->getInstList().push_back(GEP);
      StoreInst *SI = new StoreInst(StructValues[i], GEP);
      codeReplacer->getInstList().push_back(SI);
    }
  }

  // Emit the call to the function
  call = CallInst::Create(newFunction, params,
                          NumExitBlocks > 1 ? "targetBlock" : "");
  // Add debug location to the new call, if the original function has debug
  // info. In that case, the terminator of the entry block of the extracted
  // function contains the first debug location of the extracted function,
  // set in extractCodeRegion.
  if (codeReplacer->getParent()->getSubprogram()) {
    if (auto DL = newFunction->getEntryBlock().getTerminator()->getDebugLoc())
      call->setDebugLoc(DL);
  }
  codeReplacer->getInstList().push_back(call);

  // Set swifterror parameter attributes.
  for (unsigned SwiftErrArgNo : SwiftErrorArgs) {
    call->addParamAttr(SwiftErrArgNo, Attribute::SwiftError);
    newFunction->addParamAttr(SwiftErrArgNo, Attribute::SwiftError);
  }

  // Output arguments start after the inputs in the flat parameter list; in
  // aggregate mode there is only the single struct argument.
  Function::arg_iterator OutputArgBegin = newFunction->arg_begin();
  unsigned FirstOut = inputs.size();
  if (!AggregateArgs)
    std::advance(OutputArgBegin, inputs.size());

  // Reload the outputs passed in by reference, and rewrite all users outside
  // the region to use the reloaded values.
  for (unsigned i = 0, e = outputs.size(); i != e; ++i) {
    Value *Output = nullptr;
    if (AggregateArgs) {
      Value *Idx[2];
      Idx[0] = Constant::getNullValue(Type::getInt32Ty(Context));
      Idx[1] = ConstantInt::get(Type::getInt32Ty(Context), FirstOut + i);
      GetElementPtrInst *GEP = GetElementPtrInst::Create(
          StructArgTy, Struct, Idx, "gep_reload_" + outputs[i]->getName());
      codeReplacer->getInstList().push_back(GEP);
      Output = GEP;
    } else {
      Output = ReloadOutputs[i];
    }
    LoadInst *load = new LoadInst(outputs[i]->getType(), Output,
                                  outputs[i]->getName() + ".reload");
    Reloads.push_back(load);
    codeReplacer->getInstList().push_back(load);
    std::vector<User *> Users(outputs[i]->user_begin(), outputs[i]->user_end());
    for (unsigned u = 0, e = Users.size(); u != e; ++u) {
      Instruction *inst = cast<Instruction>(Users[u]);
      if (!Blocks.count(inst->getParent()))
        inst->replaceUsesOfWith(outputs[i], load);
    }
  }

  // Now we can emit a switch statement using the call as a value.
  // It starts with a placeholder null condition and codeReplacer as default;
  // both are fixed up (or the switch is replaced entirely) below.
  SwitchInst *TheSwitch =
      SwitchInst::Create(Constant::getNullValue(Type::getInt16Ty(Context)),
                         codeReplacer, 0, codeReplacer);

  // Since there may be multiple exits from the original region, make the new
  // function return an unsigned, switch on that number. This loop iterates
  // over all of the blocks in the extracted region, updating any terminator
  // instructions in the to-be-extracted region that branch to blocks that are
  // not in the region to be extracted.
  std::map<BasicBlock *, BasicBlock *> ExitBlockMap;

  unsigned switchVal = 0;
  for (BasicBlock *Block : Blocks) {
    Instruction *TI = Block->getTerminator();
    for (unsigned i = 0, e = TI->getNumSuccessors(); i != e; ++i)
      if (!Blocks.count(TI->getSuccessor(i))) {
        BasicBlock *OldTarget = TI->getSuccessor(i);
        // add a new basic block which returns the appropriate value
        BasicBlock *&NewTarget = ExitBlockMap[OldTarget];
        if (!NewTarget) {
          // If we don't already have an exit stub for this non-extracted
          // destination, create one now!
          NewTarget = BasicBlock::Create(Context,
                                         OldTarget->getName() + ".exitStub",
                                         newFunction);
          unsigned SuccNum = switchVal++;

          Value *brVal = nullptr;
          switch (NumExitBlocks) {
          case 0:
          case 1: break;  // No value needed.
          case 2:         // Conditional branch, return a bool
            brVal = ConstantInt::get(Type::getInt1Ty(Context), !SuccNum);
            break;
          default:
            brVal = ConstantInt::get(Type::getInt16Ty(Context), SuccNum);
            break;
          }

          ReturnInst::Create(Context, brVal, NewTarget);

          // Update the switch instruction.
          TheSwitch->addCase(ConstantInt::get(Type::getInt16Ty(Context),
                                              SuccNum),
                             OldTarget);
        }

        // rewrite the original branch instruction with this new target
        TI->setSuccessor(i, NewTarget);
      }
  }

  // Store the arguments right after the definition of output value.
  // This must happen after creating the exit stubs above, so that a store of
  // an invoke's result is placed in the outlined function (the invoke's
  // normal destination has been rewritten to an in-function stub by then).
  Function::arg_iterator OAI = OutputArgBegin;
  for (unsigned i = 0, e = outputs.size(); i != e; ++i) {
    auto *OutI = dyn_cast<Instruction>(outputs[i]);
    if (!OutI)
      continue;

    // Find proper insertion point.
    BasicBlock::iterator InsertPt;
    // In case OutI is an invoke, we insert the store at the beginning in the
    // 'normal destination' BB. Otherwise we insert the store right after OutI.
    if (auto *InvokeI = dyn_cast<InvokeInst>(OutI))
      InsertPt = InvokeI->getNormalDest()->getFirstInsertionPt();
    else if (auto *Phi = dyn_cast<PHINode>(OutI))
      InsertPt = Phi->getParent()->getFirstInsertionPt();
    else
      InsertPt = std::next(OutI->getIterator());

    Instruction *InsertBefore = &*InsertPt;
    assert((InsertBefore->getFunction() == newFunction ||
            Blocks.count(InsertBefore->getParent())) &&
           "InsertPt should be in new function");
    assert(OAI != newFunction->arg_end() &&
           "Number of output arguments should match "
           "the amount of defined values");
    if (AggregateArgs) {
      Value *Idx[2];
      Idx[0] = Constant::getNullValue(Type::getInt32Ty(Context));
      Idx[1] = ConstantInt::get(Type::getInt32Ty(Context), FirstOut + i);
      GetElementPtrInst *GEP = GetElementPtrInst::Create(
          StructArgTy, &*OAI, Idx, "gep_" + outputs[i]->getName(),
          InsertBefore);
      new StoreInst(outputs[i], GEP, InsertBefore);
      // Since there should be only one struct argument aggregating
      // all the output values, we shouldn't increment OAI, which always
      // points to the struct argument, in this case.
    } else {
      new StoreInst(outputs[i], &*OAI, InsertBefore);
      ++OAI;
    }
  }

  // Now that we've done the deed, simplify the switch instruction.
  Type *OldFnRetTy = TheSwitch->getParent()->getParent()->getReturnType();
  switch (NumExitBlocks) {
  case 0:
    // There are no successors (the block containing the switch itself), which
    // means that previously this was the last part of the function, and hence
    // this should be rewritten as a `ret'

    // Check if the function should return a value
    if (OldFnRetTy->isVoidTy()) {
      ReturnInst::Create(Context, nullptr, TheSwitch);  // Return void
    } else if (OldFnRetTy == TheSwitch->getCondition()->getType()) {
      // return what we have
      ReturnInst::Create(Context, TheSwitch->getCondition(), TheSwitch);
    } else {
      // Otherwise we must have code extracted an unwind or something, just
      // return whatever we want.
      ReturnInst::Create(Context,
                         Constant::getNullValue(OldFnRetTy), TheSwitch);
    }

    TheSwitch->eraseFromParent();
    break;
  case 1:
    // Only a single destination, change the switch into an unconditional
    // branch. (Successor 0 is the switch's default, codeReplacer itself;
    // successor 1 is the single real exit.)
    BranchInst::Create(TheSwitch->getSuccessor(1), TheSwitch);
    TheSwitch->eraseFromParent();
    break;
  case 2:
    // Two destinations: the i1 return value selects between them directly.
    BranchInst::Create(TheSwitch->getSuccessor(1), TheSwitch->getSuccessor(2),
                       call, TheSwitch);
    TheSwitch->eraseFromParent();
    break;
  default:
    // Otherwise, make the default destination of the switch instruction be one
    // of the other successors.
    TheSwitch->setCondition(call);
    TheSwitch->setDefaultDest(TheSwitch->getSuccessor(NumExitBlocks));
    // Remove redundant case
    TheSwitch->removeCase(SwitchInst::CaseIt(TheSwitch, NumExitBlocks-1));
    break;
  }

  // Insert lifetime markers around the reloads of any output values. The
  // allocas output values are stored in are only in-use in the codeRepl block.
  insertLifetimeMarkersSurroundingCall(M, ReloadOutputs, ReloadOutputs, call);

  return call;
}
  1140. void CodeExtractor::moveCodeToFunction(Function *newFunction) {
  1141. Function *oldFunc = (*Blocks.begin())->getParent();
  1142. Function::BasicBlockListType &oldBlocks = oldFunc->getBasicBlockList();
  1143. Function::BasicBlockListType &newBlocks = newFunction->getBasicBlockList();
  1144. for (BasicBlock *Block : Blocks) {
  1145. // Delete the basic block from the old function, and the list of blocks
  1146. oldBlocks.remove(Block);
  1147. // Insert this basic block into the new function
  1148. newBlocks.push_back(Block);
  1149. // Remove @llvm.assume calls that were moved to the new function from the
  1150. // old function's assumption cache.
  1151. if (AC)
  1152. for (auto &I : *Block)
  1153. if (match(&I, m_Intrinsic<Intrinsic::assume>()))
  1154. AC->unregisterAssumption(cast<CallInst>(&I));
  1155. }
  1156. }
  1157. void CodeExtractor::calculateNewCallTerminatorWeights(
  1158. BasicBlock *CodeReplacer,
  1159. DenseMap<BasicBlock *, BlockFrequency> &ExitWeights,
  1160. BranchProbabilityInfo *BPI) {
  1161. using Distribution = BlockFrequencyInfoImplBase::Distribution;
  1162. using BlockNode = BlockFrequencyInfoImplBase::BlockNode;
  1163. // Update the branch weights for the exit block.
  1164. Instruction *TI = CodeReplacer->getTerminator();
  1165. SmallVector<unsigned, 8> BranchWeights(TI->getNumSuccessors(), 0);
  1166. // Block Frequency distribution with dummy node.
  1167. Distribution BranchDist;
  1168. // Add each of the frequencies of the successors.
  1169. for (unsigned i = 0, e = TI->getNumSuccessors(); i < e; ++i) {
  1170. BlockNode ExitNode(i);
  1171. uint64_t ExitFreq = ExitWeights[TI->getSuccessor(i)].getFrequency();
  1172. if (ExitFreq != 0)
  1173. BranchDist.addExit(ExitNode, ExitFreq);
  1174. else
  1175. BPI->setEdgeProbability(CodeReplacer, i, BranchProbability::getZero());
  1176. }
  1177. // Check for no total weight.
  1178. if (BranchDist.Total == 0)
  1179. return;
  1180. // Normalize the distribution so that they can fit in unsigned.
  1181. BranchDist.normalize();
  1182. // Create normalized branch weights and set the metadata.
  1183. for (unsigned I = 0, E = BranchDist.Weights.size(); I < E; ++I) {
  1184. const auto &Weight = BranchDist.Weights[I];
  1185. // Get the weight and update the current BFI.
  1186. BranchWeights[Weight.TargetNode.Index] = Weight.Amount;
  1187. BranchProbability BP(Weight.Amount, BranchDist.Total);
  1188. BPI->setEdgeProbability(CodeReplacer, Weight.TargetNode.Index, BP);
  1189. }
  1190. TI->setMetadata(
  1191. LLVMContext::MD_prof,
  1192. MDBuilder(TI->getContext()).createBranchWeights(BranchWeights));
  1193. }
  1194. Function *CodeExtractor::extractCodeRegion() {
  1195. if (!isEligible())
  1196. return nullptr;
  1197. // Assumption: this is a single-entry code region, and the header is the first
  1198. // block in the region.
  1199. BasicBlock *header = *Blocks.begin();
  1200. Function *oldFunction = header->getParent();
  1201. // For functions with varargs, check that varargs handling is only done in the
  1202. // outlined function, i.e vastart and vaend are only used in outlined blocks.
  1203. if (AllowVarArgs && oldFunction->getFunctionType()->isVarArg()) {
  1204. auto containsVarArgIntrinsic = [](Instruction &I) {
  1205. if (const CallInst *CI = dyn_cast<CallInst>(&I))
  1206. if (const Function *F = CI->getCalledFunction())
  1207. return F->getIntrinsicID() == Intrinsic::vastart ||
  1208. F->getIntrinsicID() == Intrinsic::vaend;
  1209. return false;
  1210. };
  1211. for (auto &BB : *oldFunction) {
  1212. if (Blocks.count(&BB))
  1213. continue;
  1214. if (llvm::any_of(BB, containsVarArgIntrinsic))
  1215. return nullptr;
  1216. }
  1217. }
  1218. ValueSet inputs, outputs, SinkingCands, HoistingCands;
  1219. BasicBlock *CommonExit = nullptr;
  1220. // Calculate the entry frequency of the new function before we change the root
  1221. // block.
  1222. BlockFrequency EntryFreq;
  1223. if (BFI) {
  1224. assert(BPI && "Both BPI and BFI are required to preserve profile info");
  1225. for (BasicBlock *Pred : predecessors(header)) {
  1226. if (Blocks.count(Pred))
  1227. continue;
  1228. EntryFreq +=
  1229. BFI->getBlockFreq(Pred) * BPI->getEdgeProbability(Pred, header);
  1230. }
  1231. }
  1232. // If we have any return instructions in the region, split those blocks so
  1233. // that the return is not in the region.
  1234. splitReturnBlocks();
  1235. // Calculate the exit blocks for the extracted region and the total exit
  1236. // weights for each of those blocks.
  1237. DenseMap<BasicBlock *, BlockFrequency> ExitWeights;
  1238. SmallPtrSet<BasicBlock *, 1> ExitBlocks;
  1239. for (BasicBlock *Block : Blocks) {
  1240. for (succ_iterator SI = succ_begin(Block), SE = succ_end(Block); SI != SE;
  1241. ++SI) {
  1242. if (!Blocks.count(*SI)) {
  1243. // Update the branch weight for this successor.
  1244. if (BFI) {
  1245. BlockFrequency &BF = ExitWeights[*SI];
  1246. BF += BFI->getBlockFreq(Block) * BPI->getEdgeProbability(Block, *SI);
  1247. }
  1248. ExitBlocks.insert(*SI);
  1249. }
  1250. }
  1251. }
  1252. NumExitBlocks = ExitBlocks.size();
  1253. // If we have to split PHI nodes of the entry or exit blocks, do so now.
  1254. severSplitPHINodesOfEntry(header);
  1255. severSplitPHINodesOfExits(ExitBlocks);
  1256. // This takes place of the original loop
  1257. BasicBlock *codeReplacer = BasicBlock::Create(header->getContext(),
  1258. "codeRepl", oldFunction,
  1259. header);
  1260. // The new function needs a root node because other nodes can branch to the
  1261. // head of the region, but the entry node of a function cannot have preds.
  1262. BasicBlock *newFuncRoot = BasicBlock::Create(header->getContext(),
  1263. "newFuncRoot");
  1264. auto *BranchI = BranchInst::Create(header);
  1265. // If the original function has debug info, we have to add a debug location
  1266. // to the new branch instruction from the artificial entry block.
  1267. // We use the debug location of the first instruction in the extracted
  1268. // blocks, as there is no other equivalent line in the source code.
  1269. if (oldFunction->getSubprogram()) {
  1270. any_of(Blocks, [&BranchI](const BasicBlock *BB) {
  1271. return any_of(*BB, [&BranchI](const Instruction &I) {
  1272. if (!I.getDebugLoc())
  1273. return false;
  1274. BranchI->setDebugLoc(I.getDebugLoc());
  1275. return true;
  1276. });
  1277. });
  1278. }
  1279. newFuncRoot->getInstList().push_back(BranchI);
  1280. findAllocas(SinkingCands, HoistingCands, CommonExit);
  1281. assert(HoistingCands.empty() || CommonExit);
  1282. // Find inputs to, outputs from the code region.
  1283. findInputsOutputs(inputs, outputs, SinkingCands);
  1284. // Now sink all instructions which only have non-phi uses inside the region.
  1285. // Group the allocas at the start of the block, so that any bitcast uses of
  1286. // the allocas are well-defined.
  1287. AllocaInst *FirstSunkAlloca = nullptr;
  1288. for (auto *II : SinkingCands) {
  1289. if (auto *AI = dyn_cast<AllocaInst>(II)) {
  1290. AI->moveBefore(*newFuncRoot, newFuncRoot->getFirstInsertionPt());
  1291. if (!FirstSunkAlloca)
  1292. FirstSunkAlloca = AI;
  1293. }
  1294. }
  1295. assert((SinkingCands.empty() || FirstSunkAlloca) &&
  1296. "Did not expect a sink candidate without any allocas");
  1297. for (auto *II : SinkingCands) {
  1298. if (!isa<AllocaInst>(II)) {
  1299. cast<Instruction>(II)->moveAfter(FirstSunkAlloca);
  1300. }
  1301. }
  1302. if (!HoistingCands.empty()) {
  1303. auto *HoistToBlock = findOrCreateBlockForHoisting(CommonExit);
  1304. Instruction *TI = HoistToBlock->getTerminator();
  1305. for (auto *II : HoistingCands)
  1306. cast<Instruction>(II)->moveBefore(TI);
  1307. }
  1308. // Collect objects which are inputs to the extraction region and also
  1309. // referenced by lifetime start markers within it. The effects of these
  1310. // markers must be replicated in the calling function to prevent the stack
  1311. // coloring pass from merging slots which store input objects.
  1312. ValueSet LifetimesStart;
  1313. eraseLifetimeMarkersOnInputs(Blocks, SinkingCands, LifetimesStart);
  1314. // Construct new function based on inputs/outputs & add allocas for all defs.
  1315. Function *newFunction =
  1316. constructFunction(inputs, outputs, header, newFuncRoot, codeReplacer,
  1317. oldFunction, oldFunction->getParent());
  1318. // Update the entry count of the function.
  // NOTE(review): the entry count is derived from the block frequency of the
  // region entry (EntryFreq, captured before extraction -- declared above this
  // chunk); the FIXME below is upstream's, marking the PCT_Real choice as
  // provisional -- confirm against current ProfileCount semantics.
  1319. if (BFI) {
  1320. auto Count = BFI->getProfileCountFromFreq(EntryFreq.getFrequency());
  1321. if (Count.hasValue())
  1322. newFunction->setEntryCount(
  1323. ProfileCount(Count.getValue(), Function::PCT_Real)); // FIXME
  // Keep BFI consistent: the replacement block executes exactly as often as
  // the old region entry did.
  1324. BFI->setBlockFreq(codeReplacer, EntryFreq.getFrequency());
  1325. }
  // Emit the call to the new function (plus the switch over exit codes when
  // the region has multiple exits), then physically move the extracted blocks
  // out of oldFunction and into newFunction.
  1326. CallInst *TheCall =
  1327. emitCallAndSwitchStatement(newFunction, codeReplacer, inputs, outputs);
  1328. moveCodeToFunction(newFunction);
  1329. // Replicate the effects of any lifetime start/end markers which referenced
  1330. // input objects in the extraction region by placing markers around the call.
  1331. insertLifetimeMarkersSurroundingCall(
  1332. oldFunction->getParent(), LifetimesStart.getArrayRef(), {}, TheCall);
  1333. // Propagate personality info to the new function if there is one.
  1334. if (oldFunction->hasPersonalityFn())
  1335. newFunction->setPersonalityFn(oldFunction->getPersonalityFn());
  1336. // Update the branch weights for the exit block.
  1337. if (BFI && NumExitBlocks > 1)
  1338. calculateNewCallTerminatorWeights(codeReplacer, ExitWeights, BPI);
  1339. // Loop over all of the PHI nodes in the header and exit blocks, and change
  1340. // any references to the old incoming edge to be the new incoming edge.
  // Header PHIs: any predecessor that is NOT in the extracted region now
  // reaches the header through newFuncRoot, so retarget those incoming blocks.
  // The loop relies on PHIs being grouped at the start of the block.
  1341. for (BasicBlock::iterator I = header->begin(); isa<PHINode>(I); ++I) {
  1342. PHINode *PN = cast<PHINode>(I);
  1343. for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i)
  1344. if (!Blocks.count(PN->getIncomingBlock(i)))
  1345. PN->setIncomingBlock(i, newFuncRoot);
  1346. }
  // Exit-block PHIs: edges that used to come from extracted blocks now come
  // from codeReplacer. Only the first such edge is redirected; the assert
  // enforces that any additional extracted edges carried the same value, since
  // codeReplacer can contribute only one incoming value per PHI.
  1347. for (BasicBlock *ExitBB : ExitBlocks)
  1348. for (PHINode &PN : ExitBB->phis()) {
  1349. Value *IncomingCodeReplacerVal = nullptr;
  1350. for (unsigned i = 0, e = PN.getNumIncomingValues(); i != e; ++i) {
  1351. // Ignore incoming values from outside of the extracted region.
  1352. if (!Blocks.count(PN.getIncomingBlock(i)))
  1353. continue;
  1354. // Ensure that there is only one incoming value from codeReplacer.
  1355. if (!IncomingCodeReplacerVal) {
  1356. PN.setIncomingBlock(i, codeReplacer);
  1357. IncomingCodeReplacerVal = PN.getIncomingValue(i);
  1358. } else
  1359. assert(IncomingCodeReplacerVal == PN.getIncomingValue(i) &&
  1360. "PHI has two incompatbile incoming values from codeRepl");
  1361. }
  1362. }
  1363. // Erase debug info intrinsics. Variable updates within the new function are
  1364. // invisible to debuggers. This could be improved by defining a DISubprogram
  1365. // for the new function.
  1366. for (BasicBlock &BB : *newFunction) {
  1367. auto BlockIt = BB.begin();
  1368. // Remove debug info intrinsics from the new function.
  // Advance the iterator before erasing so erasure does not invalidate it.
  1369. while (BlockIt != BB.end()) {
  1370. Instruction *Inst = &*BlockIt;
  1371. ++BlockIt;
  1372. if (isa<DbgInfoIntrinsic>(Inst))
  1373. Inst->eraseFromParent();
  1374. }
  1375. // Remove debug info intrinsics which refer to values in the new function
  1376. // from the old function.
  // After moveCodeToFunction, findDbgUsers picks up debug intrinsics (left
  // behind in oldFunction) that still reference the moved values; they would
  // otherwise dangle across the function boundary.
  1377. SmallVector<DbgVariableIntrinsic *, 4> DbgUsers;
  1378. for (Instruction &I : BB)
  1379. findDbgUsers(DbgUsers, &I);
  1380. for (DbgVariableIntrinsic *DVI : DbgUsers)
  1381. DVI->eraseFromParent();
  1382. }
  1383. // Mark the new function `noreturn` if applicable. Terminators which resume
  1384. // exception propagation are treated as returning instructions. This is to
  1385. // avoid inserting traps after calls to outlined functions which unwind.
  // noreturn holds iff NO block terminates in either a ReturnInst or a
  // ResumeInst (resume counts as "returning" for the reason above).
  1386. bool doesNotReturn = none_of(*newFunction, [](const BasicBlock &BB) {
  1387. const Instruction *Term = BB.getTerminator();
  1388. return isa<ReturnInst>(Term) || isa<ResumeInst>(Term);
  1389. });
  1390. if (doesNotReturn)
  1391. newFunction->setDoesNotReturn();
  // Debug-build sanity checks: verify both functions and abort loudly on
  // failure rather than letting a malformed function escape the extractor.
  1392. LLVM_DEBUG(if (verifyFunction(*newFunction, &errs())) {
  1393. newFunction->dump();
  1394. report_fatal_error("verification of newFunction failed!");
  1395. });
  1396. LLVM_DEBUG(if (verifyFunction(*oldFunction))
  1397. report_fatal_error("verification of oldFunction failed!"));
  1398. return newFunction;
  1399. }