CodeExtractor.cpp 55 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761277127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433143414351436143714381439144014411442144314441445144614471448144914501451145214531454145514561457145814591460146114621463146414651466146714681469147014711472
  1. //===- CodeExtractor.cpp - Pull code region into a new function -----------===//
  2. //
  3. // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
  4. // See https://llvm.org/LICENSE.txt for license information.
  5. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
  6. //
  7. //===----------------------------------------------------------------------===//
  8. //
  9. // This file implements the interface to tear out a code region, such as an
  10. // individual loop or a parallel section, into a new function, replacing it with
  11. // a call to the new function.
  12. //
  13. //===----------------------------------------------------------------------===//
  14. #include "llvm/Transforms/Utils/CodeExtractor.h"
  15. #include "llvm/ADT/ArrayRef.h"
  16. #include "llvm/ADT/DenseMap.h"
  17. #include "llvm/ADT/Optional.h"
  18. #include "llvm/ADT/STLExtras.h"
  19. #include "llvm/ADT/SetVector.h"
  20. #include "llvm/ADT/SmallPtrSet.h"
  21. #include "llvm/ADT/SmallVector.h"
  22. #include "llvm/Analysis/BlockFrequencyInfo.h"
  23. #include "llvm/Analysis/BlockFrequencyInfoImpl.h"
  24. #include "llvm/Analysis/BranchProbabilityInfo.h"
  25. #include "llvm/Analysis/LoopInfo.h"
  26. #include "llvm/IR/Argument.h"
  27. #include "llvm/IR/Attributes.h"
  28. #include "llvm/IR/BasicBlock.h"
  29. #include "llvm/IR/CFG.h"
  30. #include "llvm/IR/Constant.h"
  31. #include "llvm/IR/Constants.h"
  32. #include "llvm/IR/DataLayout.h"
  33. #include "llvm/IR/DerivedTypes.h"
  34. #include "llvm/IR/Dominators.h"
  35. #include "llvm/IR/Function.h"
  36. #include "llvm/IR/GlobalValue.h"
  37. #include "llvm/IR/InstrTypes.h"
  38. #include "llvm/IR/Instruction.h"
  39. #include "llvm/IR/Instructions.h"
  40. #include "llvm/IR/IntrinsicInst.h"
  41. #include "llvm/IR/Intrinsics.h"
  42. #include "llvm/IR/LLVMContext.h"
  43. #include "llvm/IR/MDBuilder.h"
  44. #include "llvm/IR/Module.h"
  45. #include "llvm/IR/Type.h"
  46. #include "llvm/IR/User.h"
  47. #include "llvm/IR/Value.h"
  48. #include "llvm/IR/Verifier.h"
  49. #include "llvm/Pass.h"
  50. #include "llvm/Support/BlockFrequency.h"
  51. #include "llvm/Support/BranchProbability.h"
  52. #include "llvm/Support/Casting.h"
  53. #include "llvm/Support/CommandLine.h"
  54. #include "llvm/Support/Debug.h"
  55. #include "llvm/Support/ErrorHandling.h"
  56. #include "llvm/Support/raw_ostream.h"
  57. #include "llvm/Transforms/Utils/BasicBlockUtils.h"
  58. #include "llvm/Transforms/Utils/Local.h"
  59. #include <cassert>
  60. #include <cstdint>
  61. #include <iterator>
  62. #include <map>
  63. #include <set>
  64. #include <utility>
  65. #include <vector>
  66. using namespace llvm;
  67. using ProfileCount = Function::ProfileCount;
  68. #define DEBUG_TYPE "code-extractor"
  69. // Provide a command-line option to aggregate function arguments into a struct
  70. // for functions produced by the code extractor. This is useful when converting
  71. // extracted functions to pthread-based code, as only one argument (void*) can
  72. // be passed in to pthread_create().
  73. static cl::opt<bool>
  74. AggregateArgsOpt("aggregate-extracted-args", cl::Hidden,
  75. cl::desc("Aggregate arguments to code-extracted functions"));
  76. /// Test whether a block is valid for extraction.
  77. static bool isBlockValidForExtraction(const BasicBlock &BB,
  78. const SetVector<BasicBlock *> &Result,
  79. bool AllowVarArgs, bool AllowAlloca) {
  80. // taking the address of a basic block moved to another function is illegal
  81. if (BB.hasAddressTaken())
  82. return false;
  83. // don't hoist code that uses another basicblock address, as it's likely to
  84. // lead to unexpected behavior, like cross-function jumps
  85. SmallPtrSet<User const *, 16> Visited;
  86. SmallVector<User const *, 16> ToVisit;
  87. for (Instruction const &Inst : BB)
  88. ToVisit.push_back(&Inst);
  89. while (!ToVisit.empty()) {
  90. User const *Curr = ToVisit.pop_back_val();
  91. if (!Visited.insert(Curr).second)
  92. continue;
  93. if (isa<BlockAddress const>(Curr))
  94. return false; // even a reference to self is likely to be not compatible
  95. if (isa<Instruction>(Curr) && cast<Instruction>(Curr)->getParent() != &BB)
  96. continue;
  97. for (auto const &U : Curr->operands()) {
  98. if (auto *UU = dyn_cast<User>(U))
  99. ToVisit.push_back(UU);
  100. }
  101. }
  102. // If explicitly requested, allow vastart and alloca. For invoke instructions
  103. // verify that extraction is valid.
  104. for (BasicBlock::const_iterator I = BB.begin(), E = BB.end(); I != E; ++I) {
  105. if (isa<AllocaInst>(I)) {
  106. if (!AllowAlloca)
  107. return false;
  108. continue;
  109. }
  110. if (const auto *II = dyn_cast<InvokeInst>(I)) {
  111. // Unwind destination (either a landingpad, catchswitch, or cleanuppad)
  112. // must be a part of the subgraph which is being extracted.
  113. if (auto *UBB = II->getUnwindDest())
  114. if (!Result.count(UBB))
  115. return false;
  116. continue;
  117. }
  118. // All catch handlers of a catchswitch instruction as well as the unwind
  119. // destination must be in the subgraph.
  120. if (const auto *CSI = dyn_cast<CatchSwitchInst>(I)) {
  121. if (auto *UBB = CSI->getUnwindDest())
  122. if (!Result.count(UBB))
  123. return false;
  124. for (auto *HBB : CSI->handlers())
  125. if (!Result.count(const_cast<BasicBlock*>(HBB)))
  126. return false;
  127. continue;
  128. }
  129. // Make sure that entire catch handler is within subgraph. It is sufficient
  130. // to check that catch return's block is in the list.
  131. if (const auto *CPI = dyn_cast<CatchPadInst>(I)) {
  132. for (const auto *U : CPI->users())
  133. if (const auto *CRI = dyn_cast<CatchReturnInst>(U))
  134. if (!Result.count(const_cast<BasicBlock*>(CRI->getParent())))
  135. return false;
  136. continue;
  137. }
  138. // And do similar checks for cleanup handler - the entire handler must be
  139. // in subgraph which is going to be extracted. For cleanup return should
  140. // additionally check that the unwind destination is also in the subgraph.
  141. if (const auto *CPI = dyn_cast<CleanupPadInst>(I)) {
  142. for (const auto *U : CPI->users())
  143. if (const auto *CRI = dyn_cast<CleanupReturnInst>(U))
  144. if (!Result.count(const_cast<BasicBlock*>(CRI->getParent())))
  145. return false;
  146. continue;
  147. }
  148. if (const auto *CRI = dyn_cast<CleanupReturnInst>(I)) {
  149. if (auto *UBB = CRI->getUnwindDest())
  150. if (!Result.count(UBB))
  151. return false;
  152. continue;
  153. }
  154. if (const CallInst *CI = dyn_cast<CallInst>(I)) {
  155. if (const Function *F = CI->getCalledFunction()) {
  156. auto IID = F->getIntrinsicID();
  157. if (IID == Intrinsic::vastart) {
  158. if (AllowVarArgs)
  159. continue;
  160. else
  161. return false;
  162. }
  163. // Currently, we miscompile outlined copies of eh_typid_for. There are
  164. // proposals for fixing this in llvm.org/PR39545.
  165. if (IID == Intrinsic::eh_typeid_for)
  166. return false;
  167. }
  168. }
  169. }
  170. return true;
  171. }
  172. /// Build a set of blocks to extract if the input blocks are viable.
  173. static SetVector<BasicBlock *>
  174. buildExtractionBlockSet(ArrayRef<BasicBlock *> BBs, DominatorTree *DT,
  175. bool AllowVarArgs, bool AllowAlloca) {
  176. assert(!BBs.empty() && "The set of blocks to extract must be non-empty");
  177. SetVector<BasicBlock *> Result;
  178. // Loop over the blocks, adding them to our set-vector, and aborting with an
  179. // empty set if we encounter invalid blocks.
  180. for (BasicBlock *BB : BBs) {
  181. // If this block is dead, don't process it.
  182. if (DT && !DT->isReachableFromEntry(BB))
  183. continue;
  184. if (!Result.insert(BB))
  185. llvm_unreachable("Repeated basic blocks in extraction input");
  186. }
  187. for (auto *BB : Result) {
  188. if (!isBlockValidForExtraction(*BB, Result, AllowVarArgs, AllowAlloca))
  189. return {};
  190. // Make sure that the first block is not a landing pad.
  191. if (BB == Result.front()) {
  192. if (BB->isEHPad()) {
  193. LLVM_DEBUG(dbgs() << "The first block cannot be an unwind block\n");
  194. return {};
  195. }
  196. continue;
  197. }
  198. // All blocks other than the first must not have predecessors outside of
  199. // the subgraph which is being extracted.
  200. for (auto *PBB : predecessors(BB))
  201. if (!Result.count(PBB)) {
  202. LLVM_DEBUG(
  203. dbgs() << "No blocks in this region may have entries from "
  204. "outside the region except for the first block!\n");
  205. return {};
  206. }
  207. }
  208. return Result;
  209. }
  210. CodeExtractor::CodeExtractor(ArrayRef<BasicBlock *> BBs, DominatorTree *DT,
  211. bool AggregateArgs, BlockFrequencyInfo *BFI,
  212. BranchProbabilityInfo *BPI, bool AllowVarArgs,
  213. bool AllowAlloca, std::string Suffix)
  214. : DT(DT), AggregateArgs(AggregateArgs || AggregateArgsOpt), BFI(BFI),
  215. BPI(BPI), AllowVarArgs(AllowVarArgs),
  216. Blocks(buildExtractionBlockSet(BBs, DT, AllowVarArgs, AllowAlloca)),
  217. Suffix(Suffix) {}
  218. CodeExtractor::CodeExtractor(DominatorTree &DT, Loop &L, bool AggregateArgs,
  219. BlockFrequencyInfo *BFI,
  220. BranchProbabilityInfo *BPI, std::string Suffix)
  221. : DT(&DT), AggregateArgs(AggregateArgs || AggregateArgsOpt), BFI(BFI),
  222. BPI(BPI), AllowVarArgs(false),
  223. Blocks(buildExtractionBlockSet(L.getBlocks(), &DT,
  224. /* AllowVarArgs */ false,
  225. /* AllowAlloca */ false)),
  226. Suffix(Suffix) {}
  227. /// definedInRegion - Return true if the specified value is defined in the
  228. /// extracted region.
  229. static bool definedInRegion(const SetVector<BasicBlock *> &Blocks, Value *V) {
  230. if (Instruction *I = dyn_cast<Instruction>(V))
  231. if (Blocks.count(I->getParent()))
  232. return true;
  233. return false;
  234. }
  235. /// definedInCaller - Return true if the specified value is defined in the
  236. /// function being code extracted, but not in the region being extracted.
  237. /// These values must be passed in as live-ins to the function.
  238. static bool definedInCaller(const SetVector<BasicBlock *> &Blocks, Value *V) {
  239. if (isa<Argument>(V)) return true;
  240. if (Instruction *I = dyn_cast<Instruction>(V))
  241. if (!Blocks.count(I->getParent()))
  242. return true;
  243. return false;
  244. }
  245. static BasicBlock *getCommonExitBlock(const SetVector<BasicBlock *> &Blocks) {
  246. BasicBlock *CommonExitBlock = nullptr;
  247. auto hasNonCommonExitSucc = [&](BasicBlock *Block) {
  248. for (auto *Succ : successors(Block)) {
  249. // Internal edges, ok.
  250. if (Blocks.count(Succ))
  251. continue;
  252. if (!CommonExitBlock) {
  253. CommonExitBlock = Succ;
  254. continue;
  255. }
  256. if (CommonExitBlock == Succ)
  257. continue;
  258. return true;
  259. }
  260. return false;
  261. };
  262. if (any_of(Blocks, hasNonCommonExitSucc))
  263. return nullptr;
  264. return CommonExitBlock;
  265. }
  266. bool CodeExtractor::isLegalToShrinkwrapLifetimeMarkers(
  267. Instruction *Addr) const {
  268. AllocaInst *AI = cast<AllocaInst>(Addr->stripInBoundsConstantOffsets());
  269. Function *Func = (*Blocks.begin())->getParent();
  270. for (BasicBlock &BB : *Func) {
  271. if (Blocks.count(&BB))
  272. continue;
  273. for (Instruction &II : BB) {
  274. if (isa<DbgInfoIntrinsic>(II))
  275. continue;
  276. unsigned Opcode = II.getOpcode();
  277. Value *MemAddr = nullptr;
  278. switch (Opcode) {
  279. case Instruction::Store:
  280. case Instruction::Load: {
  281. if (Opcode == Instruction::Store) {
  282. StoreInst *SI = cast<StoreInst>(&II);
  283. MemAddr = SI->getPointerOperand();
  284. } else {
  285. LoadInst *LI = cast<LoadInst>(&II);
  286. MemAddr = LI->getPointerOperand();
  287. }
  288. // Global variable can not be aliased with locals.
  289. if (dyn_cast<Constant>(MemAddr))
  290. break;
  291. Value *Base = MemAddr->stripInBoundsConstantOffsets();
  292. if (!dyn_cast<AllocaInst>(Base) || Base == AI)
  293. return false;
  294. break;
  295. }
  296. default: {
  297. IntrinsicInst *IntrInst = dyn_cast<IntrinsicInst>(&II);
  298. if (IntrInst) {
  299. if (IntrInst->isLifetimeStartOrEnd())
  300. break;
  301. return false;
  302. }
  303. // Treat all the other cases conservatively if it has side effects.
  304. if (II.mayHaveSideEffects())
  305. return false;
  306. }
  307. }
  308. }
  309. }
  310. return true;
  311. }
  312. BasicBlock *
  313. CodeExtractor::findOrCreateBlockForHoisting(BasicBlock *CommonExitBlock) {
  314. BasicBlock *SinglePredFromOutlineRegion = nullptr;
  315. assert(!Blocks.count(CommonExitBlock) &&
  316. "Expect a block outside the region!");
  317. for (auto *Pred : predecessors(CommonExitBlock)) {
  318. if (!Blocks.count(Pred))
  319. continue;
  320. if (!SinglePredFromOutlineRegion) {
  321. SinglePredFromOutlineRegion = Pred;
  322. } else if (SinglePredFromOutlineRegion != Pred) {
  323. SinglePredFromOutlineRegion = nullptr;
  324. break;
  325. }
  326. }
  327. if (SinglePredFromOutlineRegion)
  328. return SinglePredFromOutlineRegion;
  329. #ifndef NDEBUG
  330. auto getFirstPHI = [](BasicBlock *BB) {
  331. BasicBlock::iterator I = BB->begin();
  332. PHINode *FirstPhi = nullptr;
  333. while (I != BB->end()) {
  334. PHINode *Phi = dyn_cast<PHINode>(I);
  335. if (!Phi)
  336. break;
  337. if (!FirstPhi) {
  338. FirstPhi = Phi;
  339. break;
  340. }
  341. }
  342. return FirstPhi;
  343. };
  344. // If there are any phi nodes, the single pred either exists or has already
  345. // be created before code extraction.
  346. assert(!getFirstPHI(CommonExitBlock) && "Phi not expected");
  347. #endif
  348. BasicBlock *NewExitBlock = CommonExitBlock->splitBasicBlock(
  349. CommonExitBlock->getFirstNonPHI()->getIterator());
  350. for (auto PI = pred_begin(CommonExitBlock), PE = pred_end(CommonExitBlock);
  351. PI != PE;) {
  352. BasicBlock *Pred = *PI++;
  353. if (Blocks.count(Pred))
  354. continue;
  355. Pred->getTerminator()->replaceUsesOfWith(CommonExitBlock, NewExitBlock);
  356. }
  357. // Now add the old exit block to the outline region.
  358. Blocks.insert(CommonExitBlock);
  359. return CommonExitBlock;
  360. }
  361. void CodeExtractor::findAllocas(ValueSet &SinkCands, ValueSet &HoistCands,
  362. BasicBlock *&ExitBlock) const {
  363. Function *Func = (*Blocks.begin())->getParent();
  364. ExitBlock = getCommonExitBlock(Blocks);
  365. for (BasicBlock &BB : *Func) {
  366. if (Blocks.count(&BB))
  367. continue;
  368. for (Instruction &II : BB) {
  369. auto *AI = dyn_cast<AllocaInst>(&II);
  370. if (!AI)
  371. continue;
  372. // Find the pair of life time markers for address 'Addr' that are either
  373. // defined inside the outline region or can legally be shrinkwrapped into
  374. // the outline region. If there are not other untracked uses of the
  375. // address, return the pair of markers if found; otherwise return a pair
  376. // of nullptr.
  377. auto GetLifeTimeMarkers =
  378. [&](Instruction *Addr, bool &SinkLifeStart,
  379. bool &HoistLifeEnd) -> std::pair<Instruction *, Instruction *> {
  380. Instruction *LifeStart = nullptr, *LifeEnd = nullptr;
  381. for (User *U : Addr->users()) {
  382. IntrinsicInst *IntrInst = dyn_cast<IntrinsicInst>(U);
  383. if (IntrInst) {
  384. if (IntrInst->getIntrinsicID() == Intrinsic::lifetime_start) {
  385. // Do not handle the case where AI has multiple start markers.
  386. if (LifeStart)
  387. return std::make_pair<Instruction *>(nullptr, nullptr);
  388. LifeStart = IntrInst;
  389. }
  390. if (IntrInst->getIntrinsicID() == Intrinsic::lifetime_end) {
  391. if (LifeEnd)
  392. return std::make_pair<Instruction *>(nullptr, nullptr);
  393. LifeEnd = IntrInst;
  394. }
  395. continue;
  396. }
  397. // Find untracked uses of the address, bail.
  398. if (!definedInRegion(Blocks, U))
  399. return std::make_pair<Instruction *>(nullptr, nullptr);
  400. }
  401. if (!LifeStart || !LifeEnd)
  402. return std::make_pair<Instruction *>(nullptr, nullptr);
  403. SinkLifeStart = !definedInRegion(Blocks, LifeStart);
  404. HoistLifeEnd = !definedInRegion(Blocks, LifeEnd);
  405. // Do legality Check.
  406. if ((SinkLifeStart || HoistLifeEnd) &&
  407. !isLegalToShrinkwrapLifetimeMarkers(Addr))
  408. return std::make_pair<Instruction *>(nullptr, nullptr);
  409. // Check to see if we have a place to do hoisting, if not, bail.
  410. if (HoistLifeEnd && !ExitBlock)
  411. return std::make_pair<Instruction *>(nullptr, nullptr);
  412. return std::make_pair(LifeStart, LifeEnd);
  413. };
  414. bool SinkLifeStart = false, HoistLifeEnd = false;
  415. auto Markers = GetLifeTimeMarkers(AI, SinkLifeStart, HoistLifeEnd);
  416. if (Markers.first) {
  417. if (SinkLifeStart)
  418. SinkCands.insert(Markers.first);
  419. SinkCands.insert(AI);
  420. if (HoistLifeEnd)
  421. HoistCands.insert(Markers.second);
  422. continue;
  423. }
  424. // Follow the bitcast.
  425. Instruction *MarkerAddr = nullptr;
  426. for (User *U : AI->users()) {
  427. if (U->stripInBoundsConstantOffsets() == AI) {
  428. SinkLifeStart = false;
  429. HoistLifeEnd = false;
  430. Instruction *Bitcast = cast<Instruction>(U);
  431. Markers = GetLifeTimeMarkers(Bitcast, SinkLifeStart, HoistLifeEnd);
  432. if (Markers.first) {
  433. MarkerAddr = Bitcast;
  434. continue;
  435. }
  436. }
  437. // Found unknown use of AI.
  438. if (!definedInRegion(Blocks, U)) {
  439. MarkerAddr = nullptr;
  440. break;
  441. }
  442. }
  443. if (MarkerAddr) {
  444. if (SinkLifeStart)
  445. SinkCands.insert(Markers.first);
  446. if (!definedInRegion(Blocks, MarkerAddr))
  447. SinkCands.insert(MarkerAddr);
  448. SinkCands.insert(AI);
  449. if (HoistLifeEnd)
  450. HoistCands.insert(Markers.second);
  451. }
  452. }
  453. }
  454. }
  455. void CodeExtractor::findInputsOutputs(ValueSet &Inputs, ValueSet &Outputs,
  456. const ValueSet &SinkCands) const {
  457. for (BasicBlock *BB : Blocks) {
  458. // If a used value is defined outside the region, it's an input. If an
  459. // instruction is used outside the region, it's an output.
  460. for (Instruction &II : *BB) {
  461. for (User::op_iterator OI = II.op_begin(), OE = II.op_end(); OI != OE;
  462. ++OI) {
  463. Value *V = *OI;
  464. if (!SinkCands.count(V) && definedInCaller(Blocks, V))
  465. Inputs.insert(V);
  466. }
  467. for (User *U : II.users())
  468. if (!definedInRegion(Blocks, U)) {
  469. Outputs.insert(&II);
  470. break;
  471. }
  472. }
  473. }
  474. }
  475. /// severSplitPHINodesOfEntry - If a PHI node has multiple inputs from outside
  476. /// of the region, we need to split the entry block of the region so that the
  477. /// PHI node is easier to deal with.
  478. void CodeExtractor::severSplitPHINodesOfEntry(BasicBlock *&Header) {
  479. unsigned NumPredsFromRegion = 0;
  480. unsigned NumPredsOutsideRegion = 0;
  481. if (Header != &Header->getParent()->getEntryBlock()) {
  482. PHINode *PN = dyn_cast<PHINode>(Header->begin());
  483. if (!PN) return; // No PHI nodes.
  484. // If the header node contains any PHI nodes, check to see if there is more
  485. // than one entry from outside the region. If so, we need to sever the
  486. // header block into two.
  487. for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i)
  488. if (Blocks.count(PN->getIncomingBlock(i)))
  489. ++NumPredsFromRegion;
  490. else
  491. ++NumPredsOutsideRegion;
  492. // If there is one (or fewer) predecessor from outside the region, we don't
  493. // need to do anything special.
  494. if (NumPredsOutsideRegion <= 1) return;
  495. }
  496. // Otherwise, we need to split the header block into two pieces: one
  497. // containing PHI nodes merging values from outside of the region, and a
  498. // second that contains all of the code for the block and merges back any
  499. // incoming values from inside of the region.
  500. BasicBlock *NewBB = SplitBlock(Header, Header->getFirstNonPHI(), DT);
  501. // We only want to code extract the second block now, and it becomes the new
  502. // header of the region.
  503. BasicBlock *OldPred = Header;
  504. Blocks.remove(OldPred);
  505. Blocks.insert(NewBB);
  506. Header = NewBB;
  507. // Okay, now we need to adjust the PHI nodes and any branches from within the
  508. // region to go to the new header block instead of the old header block.
  509. if (NumPredsFromRegion) {
  510. PHINode *PN = cast<PHINode>(OldPred->begin());
  511. // Loop over all of the predecessors of OldPred that are in the region,
  512. // changing them to branch to NewBB instead.
  513. for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i)
  514. if (Blocks.count(PN->getIncomingBlock(i))) {
  515. Instruction *TI = PN->getIncomingBlock(i)->getTerminator();
  516. TI->replaceUsesOfWith(OldPred, NewBB);
  517. }
  518. // Okay, everything within the region is now branching to the right block, we
  519. // just have to update the PHI nodes now, inserting PHI nodes into NewBB.
  520. BasicBlock::iterator AfterPHIs;
  521. for (AfterPHIs = OldPred->begin(); isa<PHINode>(AfterPHIs); ++AfterPHIs) {
  522. PHINode *PN = cast<PHINode>(AfterPHIs);
  523. // Create a new PHI node in the new region, which has an incoming value
  524. // from OldPred of PN.
  525. PHINode *NewPN = PHINode::Create(PN->getType(), 1 + NumPredsFromRegion,
  526. PN->getName() + ".ce", &NewBB->front());
  527. PN->replaceAllUsesWith(NewPN);
  528. NewPN->addIncoming(PN, OldPred);
  529. // Loop over all of the incoming value in PN, moving them to NewPN if they
  530. // are from the extracted region.
  531. for (unsigned i = 0; i != PN->getNumIncomingValues(); ++i) {
  532. if (Blocks.count(PN->getIncomingBlock(i))) {
  533. NewPN->addIncoming(PN->getIncomingValue(i), PN->getIncomingBlock(i));
  534. PN->removeIncomingValue(i);
  535. --i;
  536. }
  537. }
  538. }
  539. }
  540. }
  541. /// severSplitPHINodesOfExits - if PHI nodes in exit blocks have inputs from
  542. /// outlined region, we split these PHIs on two: one with inputs from region
  543. /// and other with remaining incoming blocks; then first PHIs are placed in
  544. /// outlined region.
  545. void CodeExtractor::severSplitPHINodesOfExits(
  546. const SmallPtrSetImpl<BasicBlock *> &Exits) {
  547. for (BasicBlock *ExitBB : Exits) {
  548. BasicBlock *NewBB = nullptr;
  549. for (PHINode &PN : ExitBB->phis()) {
  550. // Find all incoming values from the outlining region.
  551. SmallVector<unsigned, 2> IncomingVals;
  552. for (unsigned i = 0; i < PN.getNumIncomingValues(); ++i)
  553. if (Blocks.count(PN.getIncomingBlock(i)))
  554. IncomingVals.push_back(i);
  555. // Do not process PHI if there is one (or fewer) predecessor from region.
  556. // If PHI has exactly one predecessor from region, only this one incoming
  557. // will be replaced on codeRepl block, so it should be safe to skip PHI.
  558. if (IncomingVals.size() <= 1)
  559. continue;
  560. // Create block for new PHIs and add it to the list of outlined if it
  561. // wasn't done before.
  562. if (!NewBB) {
  563. NewBB = BasicBlock::Create(ExitBB->getContext(),
  564. ExitBB->getName() + ".split",
  565. ExitBB->getParent(), ExitBB);
  566. SmallVector<BasicBlock *, 4> Preds(pred_begin(ExitBB),
  567. pred_end(ExitBB));
  568. for (BasicBlock *PredBB : Preds)
  569. if (Blocks.count(PredBB))
  570. PredBB->getTerminator()->replaceUsesOfWith(ExitBB, NewBB);
  571. BranchInst::Create(ExitBB, NewBB);
  572. Blocks.insert(NewBB);
  573. }
  574. // Split this PHI.
  575. PHINode *NewPN =
  576. PHINode::Create(PN.getType(), IncomingVals.size(),
  577. PN.getName() + ".ce", NewBB->getFirstNonPHI());
  578. for (unsigned i : IncomingVals)
  579. NewPN->addIncoming(PN.getIncomingValue(i), PN.getIncomingBlock(i));
  580. for (unsigned i : reverse(IncomingVals))
  581. PN.removeIncomingValue(i, false);
  582. PN.addIncoming(NewPN, NewBB);
  583. }
  584. }
  585. }
  586. void CodeExtractor::splitReturnBlocks() {
  587. for (BasicBlock *Block : Blocks)
  588. if (ReturnInst *RI = dyn_cast<ReturnInst>(Block->getTerminator())) {
  589. BasicBlock *New =
  590. Block->splitBasicBlock(RI->getIterator(), Block->getName() + ".ret");
  591. if (DT) {
  592. // Old dominates New. New node dominates all other nodes dominated
  593. // by Old.
  594. DomTreeNode *OldNode = DT->getNode(Block);
  595. SmallVector<DomTreeNode *, 8> Children(OldNode->begin(),
  596. OldNode->end());
  597. DomTreeNode *NewNode = DT->addNewBlock(New, Block);
  598. for (DomTreeNode *I : Children)
  599. DT->changeImmediateDominator(I, NewNode);
  600. }
  601. }
  602. }
  603. /// constructFunction - make a function based on inputs and outputs, as follows:
  604. /// f(in0, ..., inN, out0, ..., outN)
  605. Function *CodeExtractor::constructFunction(const ValueSet &inputs,
  606. const ValueSet &outputs,
  607. BasicBlock *header,
  608. BasicBlock *newRootNode,
  609. BasicBlock *newHeader,
  610. Function *oldFunction,
  611. Module *M) {
  612. LLVM_DEBUG(dbgs() << "inputs: " << inputs.size() << "\n");
  613. LLVM_DEBUG(dbgs() << "outputs: " << outputs.size() << "\n");
  614. // This function returns unsigned, outputs will go back by reference.
  615. switch (NumExitBlocks) {
  616. case 0:
  617. case 1: RetTy = Type::getVoidTy(header->getContext()); break;
  618. case 2: RetTy = Type::getInt1Ty(header->getContext()); break;
  619. default: RetTy = Type::getInt16Ty(header->getContext()); break;
  620. }
  621. std::vector<Type *> paramTy;
  622. // Add the types of the input values to the function's argument list
  623. for (Value *value : inputs) {
  624. LLVM_DEBUG(dbgs() << "value used in func: " << *value << "\n");
  625. paramTy.push_back(value->getType());
  626. }
  627. // Add the types of the output values to the function's argument list.
  628. for (Value *output : outputs) {
  629. LLVM_DEBUG(dbgs() << "instr used in func: " << *output << "\n");
  630. if (AggregateArgs)
  631. paramTy.push_back(output->getType());
  632. else
  633. paramTy.push_back(PointerType::getUnqual(output->getType()));
  634. }
  635. LLVM_DEBUG({
  636. dbgs() << "Function type: " << *RetTy << " f(";
  637. for (Type *i : paramTy)
  638. dbgs() << *i << ", ";
  639. dbgs() << ")\n";
  640. });
  641. StructType *StructTy;
  642. if (AggregateArgs && (inputs.size() + outputs.size() > 0)) {
  643. StructTy = StructType::get(M->getContext(), paramTy);
  644. paramTy.clear();
  645. paramTy.push_back(PointerType::getUnqual(StructTy));
  646. }
  647. FunctionType *funcType =
  648. FunctionType::get(RetTy, paramTy,
  649. AllowVarArgs && oldFunction->isVarArg());
  650. std::string SuffixToUse =
  651. Suffix.empty()
  652. ? (header->getName().empty() ? "extracted" : header->getName().str())
  653. : Suffix;
  654. // Create the new function
  655. Function *newFunction = Function::Create(
  656. funcType, GlobalValue::InternalLinkage, oldFunction->getAddressSpace(),
  657. oldFunction->getName() + "." + SuffixToUse, M);
  658. // If the old function is no-throw, so is the new one.
  659. if (oldFunction->doesNotThrow())
  660. newFunction->setDoesNotThrow();
  661. // Inherit the uwtable attribute if we need to.
  662. if (oldFunction->hasUWTable())
  663. newFunction->setHasUWTable();
  664. // Inherit all of the target dependent attributes and white-listed
  665. // target independent attributes.
  666. // (e.g. If the extracted region contains a call to an x86.sse
  667. // instruction we need to make sure that the extracted region has the
  668. // "target-features" attribute allowing it to be lowered.
  669. // FIXME: This should be changed to check to see if a specific
  670. // attribute can not be inherited.
  671. for (const auto &Attr : oldFunction->getAttributes().getFnAttributes()) {
  672. if (Attr.isStringAttribute()) {
  673. if (Attr.getKindAsString() == "thunk")
  674. continue;
  675. } else
  676. switch (Attr.getKindAsEnum()) {
  677. // Those attributes cannot be propagated safely. Explicitly list them
  678. // here so we get a warning if new attributes are added. This list also
  679. // includes non-function attributes.
  680. case Attribute::Alignment:
  681. case Attribute::AllocSize:
  682. case Attribute::ArgMemOnly:
  683. case Attribute::Builtin:
  684. case Attribute::ByVal:
  685. case Attribute::Convergent:
  686. case Attribute::Dereferenceable:
  687. case Attribute::DereferenceableOrNull:
  688. case Attribute::InAlloca:
  689. case Attribute::InReg:
  690. case Attribute::InaccessibleMemOnly:
  691. case Attribute::InaccessibleMemOrArgMemOnly:
  692. case Attribute::JumpTable:
  693. case Attribute::Naked:
  694. case Attribute::Nest:
  695. case Attribute::NoAlias:
  696. case Attribute::NoBuiltin:
  697. case Attribute::NoCapture:
  698. case Attribute::NoReturn:
  699. case Attribute::ExpectNoReturn:
  700. case Attribute::None:
  701. case Attribute::NonNull:
  702. case Attribute::ReadNone:
  703. case Attribute::ReadOnly:
  704. case Attribute::Returned:
  705. case Attribute::ReturnsTwice:
  706. case Attribute::SExt:
  707. case Attribute::Speculatable:
  708. case Attribute::StackAlignment:
  709. case Attribute::StructRet:
  710. case Attribute::SwiftError:
  711. case Attribute::SwiftSelf:
  712. case Attribute::WriteOnly:
  713. case Attribute::ZExt:
  714. case Attribute::EndAttrKinds:
  715. continue;
  716. // Those attributes should be safe to propagate to the extracted function.
  717. case Attribute::AlwaysInline:
  718. case Attribute::Cold:
  719. case Attribute::NoRecurse:
  720. case Attribute::InlineHint:
  721. case Attribute::MinSize:
  722. case Attribute::NoDuplicate:
  723. case Attribute::NoImplicitFloat:
  724. case Attribute::NoInline:
  725. case Attribute::NonLazyBind:
  726. case Attribute::NoRedZone:
  727. case Attribute::NoUnwind:
  728. case Attribute::OptForFuzzing:
  729. case Attribute::OptimizeNone:
  730. case Attribute::OptimizeForSize:
  731. case Attribute::SafeStack:
  732. case Attribute::ShadowCallStack:
  733. case Attribute::SanitizeAddress:
  734. case Attribute::SanitizeMemory:
  735. case Attribute::SanitizeThread:
  736. case Attribute::SanitizeHWAddress:
  737. case Attribute::SpeculativeLoadHardening:
  738. case Attribute::StackProtect:
  739. case Attribute::StackProtectReq:
  740. case Attribute::StackProtectStrong:
  741. case Attribute::StrictFP:
  742. case Attribute::UWTable:
  743. case Attribute::NoCfCheck:
  744. break;
  745. }
  746. newFunction->addFnAttr(Attr);
  747. }
  748. newFunction->getBasicBlockList().push_back(newRootNode);
  749. // Create an iterator to name all of the arguments we inserted.
  750. Function::arg_iterator AI = newFunction->arg_begin();
  751. // Rewrite all users of the inputs in the extracted region to use the
  752. // arguments (or appropriate addressing into struct) instead.
  753. for (unsigned i = 0, e = inputs.size(); i != e; ++i) {
  754. Value *RewriteVal;
  755. if (AggregateArgs) {
  756. Value *Idx[2];
  757. Idx[0] = Constant::getNullValue(Type::getInt32Ty(header->getContext()));
  758. Idx[1] = ConstantInt::get(Type::getInt32Ty(header->getContext()), i);
  759. Instruction *TI = newFunction->begin()->getTerminator();
  760. GetElementPtrInst *GEP = GetElementPtrInst::Create(
  761. StructTy, &*AI, Idx, "gep_" + inputs[i]->getName(), TI);
  762. RewriteVal = new LoadInst(GEP, "loadgep_" + inputs[i]->getName(), TI);
  763. } else
  764. RewriteVal = &*AI++;
  765. std::vector<User *> Users(inputs[i]->user_begin(), inputs[i]->user_end());
  766. for (User *use : Users)
  767. if (Instruction *inst = dyn_cast<Instruction>(use))
  768. if (Blocks.count(inst->getParent()))
  769. inst->replaceUsesOfWith(inputs[i], RewriteVal);
  770. }
  771. // Set names for input and output arguments.
  772. if (!AggregateArgs) {
  773. AI = newFunction->arg_begin();
  774. for (unsigned i = 0, e = inputs.size(); i != e; ++i, ++AI)
  775. AI->setName(inputs[i]->getName());
  776. for (unsigned i = 0, e = outputs.size(); i != e; ++i, ++AI)
  777. AI->setName(outputs[i]->getName()+".out");
  778. }
  779. // Rewrite branches to basic blocks outside of the loop to new dummy blocks
  780. // within the new function. This must be done before we lose track of which
  781. // blocks were originally in the code region.
  782. std::vector<User *> Users(header->user_begin(), header->user_end());
  783. for (unsigned i = 0, e = Users.size(); i != e; ++i)
  784. // The BasicBlock which contains the branch is not in the region
  785. // modify the branch target to a new block
  786. if (Instruction *I = dyn_cast<Instruction>(Users[i]))
  787. if (I->isTerminator() && !Blocks.count(I->getParent()) &&
  788. I->getParent()->getParent() == oldFunction)
  789. I->replaceUsesOfWith(header, newHeader);
  790. return newFunction;
  791. }
  792. /// Scan the extraction region for lifetime markers which reference inputs.
  793. /// Erase these markers. Return the inputs which were referenced.
  794. ///
  795. /// The extraction region is defined by a set of blocks (\p Blocks), and a set
  796. /// of allocas which will be moved from the caller function into the extracted
  797. /// function (\p SunkAllocas).
  798. static SetVector<Value *>
  799. eraseLifetimeMarkersOnInputs(const SetVector<BasicBlock *> &Blocks,
  800. const SetVector<Value *> &SunkAllocas) {
  801. SetVector<Value *> InputObjectsWithLifetime;
  802. for (BasicBlock *BB : Blocks) {
  803. for (auto It = BB->begin(), End = BB->end(); It != End;) {
  804. auto *II = dyn_cast<IntrinsicInst>(&*It);
  805. ++It;
  806. if (!II || !II->isLifetimeStartOrEnd())
  807. continue;
  808. // Get the memory operand of the lifetime marker. If the underlying
  809. // object is a sunk alloca, or is otherwise defined in the extraction
  810. // region, the lifetime marker must not be erased.
  811. Value *Mem = II->getOperand(1)->stripInBoundsOffsets();
  812. if (SunkAllocas.count(Mem) || definedInRegion(Blocks, Mem))
  813. continue;
  814. InputObjectsWithLifetime.insert(Mem);
  815. II->eraseFromParent();
  816. }
  817. }
  818. return InputObjectsWithLifetime;
  819. }
  820. /// Insert lifetime start/end markers surrounding the call to the new function
  821. /// for objects defined in the caller.
  822. static void insertLifetimeMarkersSurroundingCall(Module *M,
  823. ArrayRef<Value *> Objects,
  824. CallInst *TheCall) {
  825. if (Objects.empty())
  826. return;
  827. LLVMContext &Ctx = M->getContext();
  828. auto Int8PtrTy = Type::getInt8PtrTy(Ctx);
  829. auto NegativeOne = ConstantInt::getSigned(Type::getInt64Ty(Ctx), -1);
  830. auto StartFn = llvm::Intrinsic::getDeclaration(
  831. M, llvm::Intrinsic::lifetime_start, Int8PtrTy);
  832. auto EndFn = llvm::Intrinsic::getDeclaration(M, llvm::Intrinsic::lifetime_end,
  833. Int8PtrTy);
  834. Instruction *Term = TheCall->getParent()->getTerminator();
  835. for (Value *Mem : Objects) {
  836. assert((!isa<Instruction>(Mem) ||
  837. cast<Instruction>(Mem)->getFunction() == TheCall->getFunction()) &&
  838. "Input memory not defined in original function");
  839. Value *MemAsI8Ptr = nullptr;
  840. if (Mem->getType() == Int8PtrTy)
  841. MemAsI8Ptr = Mem;
  842. else
  843. MemAsI8Ptr =
  844. CastInst::CreatePointerCast(Mem, Int8PtrTy, "lt.cast", TheCall);
  845. auto StartMarker = CallInst::Create(StartFn, {NegativeOne, MemAsI8Ptr});
  846. StartMarker->insertBefore(TheCall);
  847. auto EndMarker = CallInst::Create(EndFn, {NegativeOne, MemAsI8Ptr});
  848. EndMarker->insertBefore(Term);
  849. }
  850. }
  851. /// emitCallAndSwitchStatement - This method sets up the caller side by adding
  852. /// the call instruction, splitting any PHI nodes in the header block as
  853. /// necessary.
  854. CallInst *CodeExtractor::emitCallAndSwitchStatement(Function *newFunction,
  855. BasicBlock *codeReplacer,
  856. ValueSet &inputs,
  857. ValueSet &outputs) {
  858. // Emit a call to the new function, passing in: *pointer to struct (if
  859. // aggregating parameters), or plan inputs and allocated memory for outputs
  860. std::vector<Value *> params, StructValues, ReloadOutputs, Reloads;
  861. Module *M = newFunction->getParent();
  862. LLVMContext &Context = M->getContext();
  863. const DataLayout &DL = M->getDataLayout();
  864. CallInst *call = nullptr;
  865. // Add inputs as params, or to be filled into the struct
  866. for (Value *input : inputs)
  867. if (AggregateArgs)
  868. StructValues.push_back(input);
  869. else
  870. params.push_back(input);
  871. // Create allocas for the outputs
  872. for (Value *output : outputs) {
  873. if (AggregateArgs) {
  874. StructValues.push_back(output);
  875. } else {
  876. AllocaInst *alloca =
  877. new AllocaInst(output->getType(), DL.getAllocaAddrSpace(),
  878. nullptr, output->getName() + ".loc",
  879. &codeReplacer->getParent()->front().front());
  880. ReloadOutputs.push_back(alloca);
  881. params.push_back(alloca);
  882. }
  883. }
  884. StructType *StructArgTy = nullptr;
  885. AllocaInst *Struct = nullptr;
  886. if (AggregateArgs && (inputs.size() + outputs.size() > 0)) {
  887. std::vector<Type *> ArgTypes;
  888. for (ValueSet::iterator v = StructValues.begin(),
  889. ve = StructValues.end(); v != ve; ++v)
  890. ArgTypes.push_back((*v)->getType());
  891. // Allocate a struct at the beginning of this function
  892. StructArgTy = StructType::get(newFunction->getContext(), ArgTypes);
  893. Struct = new AllocaInst(StructArgTy, DL.getAllocaAddrSpace(), nullptr,
  894. "structArg",
  895. &codeReplacer->getParent()->front().front());
  896. params.push_back(Struct);
  897. for (unsigned i = 0, e = inputs.size(); i != e; ++i) {
  898. Value *Idx[2];
  899. Idx[0] = Constant::getNullValue(Type::getInt32Ty(Context));
  900. Idx[1] = ConstantInt::get(Type::getInt32Ty(Context), i);
  901. GetElementPtrInst *GEP = GetElementPtrInst::Create(
  902. StructArgTy, Struct, Idx, "gep_" + StructValues[i]->getName());
  903. codeReplacer->getInstList().push_back(GEP);
  904. StoreInst *SI = new StoreInst(StructValues[i], GEP);
  905. codeReplacer->getInstList().push_back(SI);
  906. }
  907. }
  908. // Emit the call to the function
  909. call = CallInst::Create(newFunction, params,
  910. NumExitBlocks > 1 ? "targetBlock" : "");
  911. // Add debug location to the new call, if the original function has debug
  912. // info. In that case, the terminator of the entry block of the extracted
  913. // function contains the first debug location of the extracted function,
  914. // set in extractCodeRegion.
  915. if (codeReplacer->getParent()->getSubprogram()) {
  916. if (auto DL = newFunction->getEntryBlock().getTerminator()->getDebugLoc())
  917. call->setDebugLoc(DL);
  918. }
  919. codeReplacer->getInstList().push_back(call);
  920. Function::arg_iterator OutputArgBegin = newFunction->arg_begin();
  921. unsigned FirstOut = inputs.size();
  922. if (!AggregateArgs)
  923. std::advance(OutputArgBegin, inputs.size());
  924. // Reload the outputs passed in by reference.
  925. Function::arg_iterator OAI = OutputArgBegin;
  926. for (unsigned i = 0, e = outputs.size(); i != e; ++i) {
  927. Value *Output = nullptr;
  928. if (AggregateArgs) {
  929. Value *Idx[2];
  930. Idx[0] = Constant::getNullValue(Type::getInt32Ty(Context));
  931. Idx[1] = ConstantInt::get(Type::getInt32Ty(Context), FirstOut + i);
  932. GetElementPtrInst *GEP = GetElementPtrInst::Create(
  933. StructArgTy, Struct, Idx, "gep_reload_" + outputs[i]->getName());
  934. codeReplacer->getInstList().push_back(GEP);
  935. Output = GEP;
  936. } else {
  937. Output = ReloadOutputs[i];
  938. }
  939. LoadInst *load = new LoadInst(Output, outputs[i]->getName()+".reload");
  940. Reloads.push_back(load);
  941. codeReplacer->getInstList().push_back(load);
  942. std::vector<User *> Users(outputs[i]->user_begin(), outputs[i]->user_end());
  943. for (unsigned u = 0, e = Users.size(); u != e; ++u) {
  944. Instruction *inst = cast<Instruction>(Users[u]);
  945. if (!Blocks.count(inst->getParent()))
  946. inst->replaceUsesOfWith(outputs[i], load);
  947. }
  948. // Store to argument right after the definition of output value.
  949. auto *OutI = dyn_cast<Instruction>(outputs[i]);
  950. if (!OutI)
  951. continue;
  952. // Find proper insertion point.
  953. BasicBlock::iterator InsertPt;
  954. // In case OutI is an invoke, we insert the store at the beginning in the
  955. // 'normal destination' BB. Otherwise we insert the store right after OutI.
  956. if (auto *InvokeI = dyn_cast<InvokeInst>(OutI))
  957. InsertPt = InvokeI->getNormalDest()->getFirstInsertionPt();
  958. else if (auto *Phi = dyn_cast<PHINode>(OutI))
  959. InsertPt = Phi->getParent()->getFirstInsertionPt();
  960. else
  961. InsertPt = std::next(OutI->getIterator());
  962. assert(OAI != newFunction->arg_end() &&
  963. "Number of output arguments should match "
  964. "the amount of defined values");
  965. if (AggregateArgs) {
  966. Value *Idx[2];
  967. Idx[0] = Constant::getNullValue(Type::getInt32Ty(Context));
  968. Idx[1] = ConstantInt::get(Type::getInt32Ty(Context), FirstOut + i);
  969. GetElementPtrInst *GEP = GetElementPtrInst::Create(
  970. StructArgTy, &*OAI, Idx, "gep_" + outputs[i]->getName(), &*InsertPt);
  971. new StoreInst(outputs[i], GEP, &*InsertPt);
  972. // Since there should be only one struct argument aggregating
  973. // all the output values, we shouldn't increment OAI, which always
  974. // points to the struct argument, in this case.
  975. } else {
  976. new StoreInst(outputs[i], &*OAI, &*InsertPt);
  977. ++OAI;
  978. }
  979. }
  980. // Now we can emit a switch statement using the call as a value.
  981. SwitchInst *TheSwitch =
  982. SwitchInst::Create(Constant::getNullValue(Type::getInt16Ty(Context)),
  983. codeReplacer, 0, codeReplacer);
  984. // Since there may be multiple exits from the original region, make the new
  985. // function return an unsigned, switch on that number. This loop iterates
  986. // over all of the blocks in the extracted region, updating any terminator
  987. // instructions in the to-be-extracted region that branch to blocks that are
  988. // not in the region to be extracted.
  989. std::map<BasicBlock *, BasicBlock *> ExitBlockMap;
  990. unsigned switchVal = 0;
  991. for (BasicBlock *Block : Blocks) {
  992. Instruction *TI = Block->getTerminator();
  993. for (unsigned i = 0, e = TI->getNumSuccessors(); i != e; ++i)
  994. if (!Blocks.count(TI->getSuccessor(i))) {
  995. BasicBlock *OldTarget = TI->getSuccessor(i);
  996. // add a new basic block which returns the appropriate value
  997. BasicBlock *&NewTarget = ExitBlockMap[OldTarget];
  998. if (!NewTarget) {
  999. // If we don't already have an exit stub for this non-extracted
  1000. // destination, create one now!
  1001. NewTarget = BasicBlock::Create(Context,
  1002. OldTarget->getName() + ".exitStub",
  1003. newFunction);
  1004. unsigned SuccNum = switchVal++;
  1005. Value *brVal = nullptr;
  1006. switch (NumExitBlocks) {
  1007. case 0:
  1008. case 1: break; // No value needed.
  1009. case 2: // Conditional branch, return a bool
  1010. brVal = ConstantInt::get(Type::getInt1Ty(Context), !SuccNum);
  1011. break;
  1012. default:
  1013. brVal = ConstantInt::get(Type::getInt16Ty(Context), SuccNum);
  1014. break;
  1015. }
  1016. ReturnInst::Create(Context, brVal, NewTarget);
  1017. // Update the switch instruction.
  1018. TheSwitch->addCase(ConstantInt::get(Type::getInt16Ty(Context),
  1019. SuccNum),
  1020. OldTarget);
  1021. }
  1022. // rewrite the original branch instruction with this new target
  1023. TI->setSuccessor(i, NewTarget);
  1024. }
  1025. }
  1026. // Now that we've done the deed, simplify the switch instruction.
  1027. Type *OldFnRetTy = TheSwitch->getParent()->getParent()->getReturnType();
  1028. switch (NumExitBlocks) {
  1029. case 0:
  1030. // There are no successors (the block containing the switch itself), which
  1031. // means that previously this was the last part of the function, and hence
  1032. // this should be rewritten as a `ret'
  1033. // Check if the function should return a value
  1034. if (OldFnRetTy->isVoidTy()) {
  1035. ReturnInst::Create(Context, nullptr, TheSwitch); // Return void
  1036. } else if (OldFnRetTy == TheSwitch->getCondition()->getType()) {
  1037. // return what we have
  1038. ReturnInst::Create(Context, TheSwitch->getCondition(), TheSwitch);
  1039. } else {
  1040. // Otherwise we must have code extracted an unwind or something, just
  1041. // return whatever we want.
  1042. ReturnInst::Create(Context,
  1043. Constant::getNullValue(OldFnRetTy), TheSwitch);
  1044. }
  1045. TheSwitch->eraseFromParent();
  1046. break;
  1047. case 1:
  1048. // Only a single destination, change the switch into an unconditional
  1049. // branch.
  1050. BranchInst::Create(TheSwitch->getSuccessor(1), TheSwitch);
  1051. TheSwitch->eraseFromParent();
  1052. break;
  1053. case 2:
  1054. BranchInst::Create(TheSwitch->getSuccessor(1), TheSwitch->getSuccessor(2),
  1055. call, TheSwitch);
  1056. TheSwitch->eraseFromParent();
  1057. break;
  1058. default:
  1059. // Otherwise, make the default destination of the switch instruction be one
  1060. // of the other successors.
  1061. TheSwitch->setCondition(call);
  1062. TheSwitch->setDefaultDest(TheSwitch->getSuccessor(NumExitBlocks));
  1063. // Remove redundant case
  1064. TheSwitch->removeCase(SwitchInst::CaseIt(TheSwitch, NumExitBlocks-1));
  1065. break;
  1066. }
  1067. // Insert lifetime markers around the reloads of any output values. The
  1068. // allocas output values are stored in are only in-use in the codeRepl block.
  1069. insertLifetimeMarkersSurroundingCall(M, ReloadOutputs, call);
  1070. return call;
  1071. }
  1072. void CodeExtractor::moveCodeToFunction(Function *newFunction) {
  1073. Function *oldFunc = (*Blocks.begin())->getParent();
  1074. Function::BasicBlockListType &oldBlocks = oldFunc->getBasicBlockList();
  1075. Function::BasicBlockListType &newBlocks = newFunction->getBasicBlockList();
  1076. for (BasicBlock *Block : Blocks) {
  1077. // Delete the basic block from the old function, and the list of blocks
  1078. oldBlocks.remove(Block);
  1079. // Insert this basic block into the new function
  1080. newBlocks.push_back(Block);
  1081. }
  1082. }
  1083. void CodeExtractor::calculateNewCallTerminatorWeights(
  1084. BasicBlock *CodeReplacer,
  1085. DenseMap<BasicBlock *, BlockFrequency> &ExitWeights,
  1086. BranchProbabilityInfo *BPI) {
  1087. using Distribution = BlockFrequencyInfoImplBase::Distribution;
  1088. using BlockNode = BlockFrequencyInfoImplBase::BlockNode;
  1089. // Update the branch weights for the exit block.
  1090. Instruction *TI = CodeReplacer->getTerminator();
  1091. SmallVector<unsigned, 8> BranchWeights(TI->getNumSuccessors(), 0);
  1092. // Block Frequency distribution with dummy node.
  1093. Distribution BranchDist;
  1094. // Add each of the frequencies of the successors.
  1095. for (unsigned i = 0, e = TI->getNumSuccessors(); i < e; ++i) {
  1096. BlockNode ExitNode(i);
  1097. uint64_t ExitFreq = ExitWeights[TI->getSuccessor(i)].getFrequency();
  1098. if (ExitFreq != 0)
  1099. BranchDist.addExit(ExitNode, ExitFreq);
  1100. else
  1101. BPI->setEdgeProbability(CodeReplacer, i, BranchProbability::getZero());
  1102. }
  1103. // Check for no total weight.
  1104. if (BranchDist.Total == 0)
  1105. return;
  1106. // Normalize the distribution so that they can fit in unsigned.
  1107. BranchDist.normalize();
  1108. // Create normalized branch weights and set the metadata.
  1109. for (unsigned I = 0, E = BranchDist.Weights.size(); I < E; ++I) {
  1110. const auto &Weight = BranchDist.Weights[I];
  1111. // Get the weight and update the current BFI.
  1112. BranchWeights[Weight.TargetNode.Index] = Weight.Amount;
  1113. BranchProbability BP(Weight.Amount, BranchDist.Total);
  1114. BPI->setEdgeProbability(CodeReplacer, Weight.TargetNode.Index, BP);
  1115. }
  1116. TI->setMetadata(
  1117. LLVMContext::MD_prof,
  1118. MDBuilder(TI->getContext()).createBranchWeights(BranchWeights));
  1119. }
  1120. Function *CodeExtractor::extractCodeRegion() {
  1121. if (!isEligible())
  1122. return nullptr;
  1123. // Assumption: this is a single-entry code region, and the header is the first
  1124. // block in the region.
  1125. BasicBlock *header = *Blocks.begin();
  1126. Function *oldFunction = header->getParent();
  1127. // For functions with varargs, check that varargs handling is only done in the
  1128. // outlined function, i.e vastart and vaend are only used in outlined blocks.
  1129. if (AllowVarArgs && oldFunction->getFunctionType()->isVarArg()) {
  1130. auto containsVarArgIntrinsic = [](Instruction &I) {
  1131. if (const CallInst *CI = dyn_cast<CallInst>(&I))
  1132. if (const Function *F = CI->getCalledFunction())
  1133. return F->getIntrinsicID() == Intrinsic::vastart ||
  1134. F->getIntrinsicID() == Intrinsic::vaend;
  1135. return false;
  1136. };
  1137. for (auto &BB : *oldFunction) {
  1138. if (Blocks.count(&BB))
  1139. continue;
  1140. if (llvm::any_of(BB, containsVarArgIntrinsic))
  1141. return nullptr;
  1142. }
  1143. }
  1144. ValueSet inputs, outputs, SinkingCands, HoistingCands;
  1145. BasicBlock *CommonExit = nullptr;
  1146. // Calculate the entry frequency of the new function before we change the root
  1147. // block.
  1148. BlockFrequency EntryFreq;
  1149. if (BFI) {
  1150. assert(BPI && "Both BPI and BFI are required to preserve profile info");
  1151. for (BasicBlock *Pred : predecessors(header)) {
  1152. if (Blocks.count(Pred))
  1153. continue;
  1154. EntryFreq +=
  1155. BFI->getBlockFreq(Pred) * BPI->getEdgeProbability(Pred, header);
  1156. }
  1157. }
  1158. // If we have any return instructions in the region, split those blocks so
  1159. // that the return is not in the region.
  1160. splitReturnBlocks();
  1161. // Calculate the exit blocks for the extracted region and the total exit
  1162. // weights for each of those blocks.
  1163. DenseMap<BasicBlock *, BlockFrequency> ExitWeights;
  1164. SmallPtrSet<BasicBlock *, 1> ExitBlocks;
  1165. for (BasicBlock *Block : Blocks) {
  1166. for (succ_iterator SI = succ_begin(Block), SE = succ_end(Block); SI != SE;
  1167. ++SI) {
  1168. if (!Blocks.count(*SI)) {
  1169. // Update the branch weight for this successor.
  1170. if (BFI) {
  1171. BlockFrequency &BF = ExitWeights[*SI];
  1172. BF += BFI->getBlockFreq(Block) * BPI->getEdgeProbability(Block, *SI);
  1173. }
  1174. ExitBlocks.insert(*SI);
  1175. }
  1176. }
  1177. }
  1178. NumExitBlocks = ExitBlocks.size();
  1179. // If we have to split PHI nodes of the entry or exit blocks, do so now.
  1180. severSplitPHINodesOfEntry(header);
  1181. severSplitPHINodesOfExits(ExitBlocks);
  1182. // This takes place of the original loop
  1183. BasicBlock *codeReplacer = BasicBlock::Create(header->getContext(),
  1184. "codeRepl", oldFunction,
  1185. header);
  1186. // The new function needs a root node because other nodes can branch to the
  1187. // head of the region, but the entry node of a function cannot have preds.
  1188. BasicBlock *newFuncRoot = BasicBlock::Create(header->getContext(),
  1189. "newFuncRoot");
  1190. auto *BranchI = BranchInst::Create(header);
  1191. // If the original function has debug info, we have to add a debug location
  1192. // to the new branch instruction from the artificial entry block.
  1193. // We use the debug location of the first instruction in the extracted
  1194. // blocks, as there is no other equivalent line in the source code.
  1195. if (oldFunction->getSubprogram()) {
  1196. any_of(Blocks, [&BranchI](const BasicBlock *BB) {
  1197. return any_of(*BB, [&BranchI](const Instruction &I) {
  1198. if (!I.getDebugLoc())
  1199. return false;
  1200. BranchI->setDebugLoc(I.getDebugLoc());
  1201. return true;
  1202. });
  1203. });
  1204. }
  1205. newFuncRoot->getInstList().push_back(BranchI);
  1206. findAllocas(SinkingCands, HoistingCands, CommonExit);
  1207. assert(HoistingCands.empty() || CommonExit);
  1208. // Find inputs to, outputs from the code region.
  1209. findInputsOutputs(inputs, outputs, SinkingCands);
  1210. // Now sink all instructions which only have non-phi uses inside the region
  1211. for (auto *II : SinkingCands)
  1212. cast<Instruction>(II)->moveBefore(*newFuncRoot,
  1213. newFuncRoot->getFirstInsertionPt());
  1214. if (!HoistingCands.empty()) {
  1215. auto *HoistToBlock = findOrCreateBlockForHoisting(CommonExit);
  1216. Instruction *TI = HoistToBlock->getTerminator();
  1217. for (auto *II : HoistingCands)
  1218. cast<Instruction>(II)->moveBefore(TI);
  1219. }
  1220. // Collect objects which are inputs to the extraction region and also
  1221. // referenced by lifetime start/end markers within it. The effects of these
  1222. // markers must be replicated in the calling function to prevent the stack
  1223. // coloring pass from merging slots which store input objects.
  1224. ValueSet InputObjectsWithLifetime =
  1225. eraseLifetimeMarkersOnInputs(Blocks, SinkingCands);
  1226. // Construct new function based on inputs/outputs & add allocas for all defs.
  1227. Function *newFunction =
  1228. constructFunction(inputs, outputs, header, newFuncRoot, codeReplacer,
  1229. oldFunction, oldFunction->getParent());
  1230. // Update the entry count of the function.
  1231. if (BFI) {
  1232. auto Count = BFI->getProfileCountFromFreq(EntryFreq.getFrequency());
  1233. if (Count.hasValue())
  1234. newFunction->setEntryCount(
  1235. ProfileCount(Count.getValue(), Function::PCT_Real)); // FIXME
  1236. BFI->setBlockFreq(codeReplacer, EntryFreq.getFrequency());
  1237. }
  1238. CallInst *TheCall =
  1239. emitCallAndSwitchStatement(newFunction, codeReplacer, inputs, outputs);
  1240. moveCodeToFunction(newFunction);
  1241. // Replicate the effects of any lifetime start/end markers which referenced
  1242. // input objects in the extraction region by placing markers around the call.
  1243. insertLifetimeMarkersSurroundingCall(oldFunction->getParent(),
  1244. InputObjectsWithLifetime.getArrayRef(),
  1245. TheCall);
  1246. // Propagate personality info to the new function if there is one.
  1247. if (oldFunction->hasPersonalityFn())
  1248. newFunction->setPersonalityFn(oldFunction->getPersonalityFn());
  1249. // Update the branch weights for the exit block.
  1250. if (BFI && NumExitBlocks > 1)
  1251. calculateNewCallTerminatorWeights(codeReplacer, ExitWeights, BPI);
  1252. // Loop over all of the PHI nodes in the header and exit blocks, and change
  1253. // any references to the old incoming edge to be the new incoming edge.
  1254. for (BasicBlock::iterator I = header->begin(); isa<PHINode>(I); ++I) {
  1255. PHINode *PN = cast<PHINode>(I);
  1256. for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i)
  1257. if (!Blocks.count(PN->getIncomingBlock(i)))
  1258. PN->setIncomingBlock(i, newFuncRoot);
  1259. }
  1260. for (BasicBlock *ExitBB : ExitBlocks)
  1261. for (PHINode &PN : ExitBB->phis()) {
  1262. Value *IncomingCodeReplacerVal = nullptr;
  1263. for (unsigned i = 0, e = PN.getNumIncomingValues(); i != e; ++i) {
  1264. // Ignore incoming values from outside of the extracted region.
  1265. if (!Blocks.count(PN.getIncomingBlock(i)))
  1266. continue;
  1267. // Ensure that there is only one incoming value from codeReplacer.
  1268. if (!IncomingCodeReplacerVal) {
  1269. PN.setIncomingBlock(i, codeReplacer);
  1270. IncomingCodeReplacerVal = PN.getIncomingValue(i);
  1271. } else
  1272. assert(IncomingCodeReplacerVal == PN.getIncomingValue(i) &&
  1273. "PHI has two incompatbile incoming values from codeRepl");
  1274. }
  1275. }
  1276. // Erase debug info intrinsics. Variable updates within the new function are
  1277. // invisible to debuggers. This could be improved by defining a DISubprogram
  1278. // for the new function.
  1279. for (BasicBlock &BB : *newFunction) {
  1280. auto BlockIt = BB.begin();
  1281. // Remove debug info intrinsics from the new function.
  1282. while (BlockIt != BB.end()) {
  1283. Instruction *Inst = &*BlockIt;
  1284. ++BlockIt;
  1285. if (isa<DbgInfoIntrinsic>(Inst))
  1286. Inst->eraseFromParent();
  1287. }
  1288. // Remove debug info intrinsics which refer to values in the new function
  1289. // from the old function.
  1290. SmallVector<DbgVariableIntrinsic *, 4> DbgUsers;
  1291. for (Instruction &I : BB)
  1292. findDbgUsers(DbgUsers, &I);
  1293. for (DbgVariableIntrinsic *DVI : DbgUsers)
  1294. DVI->eraseFromParent();
  1295. }
  1296. // Mark the new function `noreturn` if applicable. Terminators which resume
  1297. // exception propagation are treated as returning instructions. This is to
  1298. // avoid inserting traps after calls to outlined functions which unwind.
  1299. bool doesNotReturn = none_of(*newFunction, [](const BasicBlock &BB) {
  1300. const Instruction *Term = BB.getTerminator();
  1301. return isa<ReturnInst>(Term) || isa<ResumeInst>(Term);
  1302. });
  1303. if (doesNotReturn)
  1304. newFunction->setDoesNotReturn();
  1305. LLVM_DEBUG(if (verifyFunction(*newFunction, &errs())) {
  1306. newFunction->dump();
  1307. report_fatal_error("verification of newFunction failed!");
  1308. });
  1309. LLVM_DEBUG(if (verifyFunction(*oldFunction))
  1310. report_fatal_error("verification of oldFunction failed!"));
  1311. return newFunction;
  1312. }