//===- CoroFrame.cpp - Builds and manipulates coroutine frame -------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
// This file contains classes used to discover if, for a particular value,
// there is a path from its definition to a use that crosses a suspend point.
//
// Using the information discovered we form a Coroutine Frame structure to
// contain those values. All uses of those values are replaced with appropriate
// GEP + load from the coroutine frame. At the point of the definition we spill
// the value into the coroutine frame.
//
// TODO: pack values tightly using liveness info.
//===----------------------------------------------------------------------===//
#include "CoroInternal.h"
#include "llvm/ADT/BitVector.h"
#include "llvm/Transforms/Utils/Local.h"
#include "llvm/Config/llvm-config.h"
#include "llvm/IR/CFG.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstIterator.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Support/circular_raw_ostream.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"

using namespace llvm;

// The "coro-suspend-crossing" flag is very noisy. There is another debug type,
// "coro-frame", which results in leaner debug spew.
#define DEBUG_TYPE "coro-suspend-crossing"

enum { SmallVectorThreshold = 32 };
// Provides a two-way mapping between blocks and numbers.
namespace {
class BlockToIndexMapping {
  SmallVector<BasicBlock *, SmallVectorThreshold> V;

public:
  size_t size() const { return V.size(); }

  BlockToIndexMapping(Function &F) {
    for (BasicBlock &BB : F)
      V.push_back(&BB);
    llvm::sort(V);
  }

  size_t blockToIndex(BasicBlock *BB) const {
    auto *I = llvm::lower_bound(V, BB);
    assert(I != V.end() && *I == BB && "BasicBlockNumbering: Unknown block");
    return I - V.begin();
  }

  BasicBlock *indexToBlock(unsigned Index) const { return V[Index]; }
};
} // end anonymous namespace
// The SuspendCrossingInfo maintains data that allows us to answer the question
// of whether, given two BasicBlocks A and B, there is a path from A to B that
// passes through a suspend point.
//
// For every basic block 'i' it maintains a BlockData that consists of:
//   Consumes: a bit vector which contains a set of indices of blocks that can
//             reach block 'i'
//   Kills: a bit vector which contains a set of indices of blocks that can
//          reach block 'i', but at least one of those paths crosses a suspend
//          point
//   Suspend: a boolean indicating whether block 'i' contains a suspend point.
//   End: a boolean indicating whether block 'i' contains a coro.end intrinsic.
//
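// For example (illustrative, hypothetical block names): consider a CFG
//   Entry -> Susp -> Use, plus a direct edge Entry -> Use,
// where Susp contains a suspend point. After the propagation below,
// Use.Consumes is {Entry, Susp, Use} and Use.Kills is {Entry, Susp}: every
// path from Susp to Use crosses the suspend, and at least one path from Entry
// to Use does as well, so values defined in Entry or Susp and used in Use
// must be spilled to the coroutine frame.
//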
namespace {
struct SuspendCrossingInfo {
  BlockToIndexMapping Mapping;

  struct BlockData {
    BitVector Consumes;
    BitVector Kills;
    bool Suspend = false;
    bool End = false;
  };
  SmallVector<BlockData, SmallVectorThreshold> Block;

  iterator_range<succ_iterator> successors(BlockData const &BD) const {
    BasicBlock *BB = Mapping.indexToBlock(&BD - &Block[0]);
    return llvm::successors(BB);
  }

  BlockData &getBlockData(BasicBlock *BB) {
    return Block[Mapping.blockToIndex(BB)];
  }

  void dump() const;
  void dump(StringRef Label, BitVector const &BV) const;

  SuspendCrossingInfo(Function &F, coro::Shape &Shape);

  bool hasPathCrossingSuspendPoint(BasicBlock *DefBB, BasicBlock *UseBB) const {
    size_t const DefIndex = Mapping.blockToIndex(DefBB);
    size_t const UseIndex = Mapping.blockToIndex(UseBB);

    assert(Block[UseIndex].Consumes[DefIndex] && "use must consume def");
    bool const Result = Block[UseIndex].Kills[DefIndex];
    LLVM_DEBUG(dbgs() << UseBB->getName() << " => " << DefBB->getName()
                      << " answer is " << Result << "\n");
    return Result;
  }

  bool isDefinitionAcrossSuspend(BasicBlock *DefBB, User *U) const {
    auto *I = cast<Instruction>(U);

    // We rewrote PHINodes, so that only the ones with exactly one incoming
    // value need to be analyzed.
    if (auto *PN = dyn_cast<PHINode>(I))
      if (PN->getNumIncomingValues() > 1)
        return false;

    BasicBlock *UseBB = I->getParent();
    return hasPathCrossingSuspendPoint(DefBB, UseBB);
  }

  bool isDefinitionAcrossSuspend(Argument &A, User *U) const {
    return isDefinitionAcrossSuspend(&A.getParent()->getEntryBlock(), U);
  }

  bool isDefinitionAcrossSuspend(Instruction &I, User *U) const {
    return isDefinitionAcrossSuspend(I.getParent(), U);
  }
};
} // end anonymous namespace
#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
LLVM_DUMP_METHOD void SuspendCrossingInfo::dump(StringRef Label,
                                                BitVector const &BV) const {
  dbgs() << Label << ":";
  for (size_t I = 0, N = BV.size(); I < N; ++I)
    if (BV[I])
      dbgs() << " " << Mapping.indexToBlock(I)->getName();
  dbgs() << "\n";
}

LLVM_DUMP_METHOD void SuspendCrossingInfo::dump() const {
  for (size_t I = 0, N = Block.size(); I < N; ++I) {
    BasicBlock *const B = Mapping.indexToBlock(I);
    dbgs() << B->getName() << ":\n";
    dump(" Consumes", Block[I].Consumes);
    dump(" Kills", Block[I].Kills);
  }
  dbgs() << "\n";
}
#endif
SuspendCrossingInfo::SuspendCrossingInfo(Function &F, coro::Shape &Shape)
    : Mapping(F) {
  const size_t N = Mapping.size();
  Block.resize(N);

  // Initialize every block so that it consumes itself.
  for (size_t I = 0; I < N; ++I) {
    auto &B = Block[I];
    B.Consumes.resize(N);
    B.Kills.resize(N);
    B.Consumes.set(I);
  }

  // Mark all CoroEnd blocks. We do not propagate Kills beyond coro.ends, as
  // the code beyond coro.end is reachable during the initial invocation of the
  // coroutine.
  for (auto *CE : Shape.CoroEnds)
    getBlockData(CE->getParent()).End = true;

  // Mark all suspend blocks and indicate that they kill everything they
  // consume. Note that crossing coro.save also requires a spill, as any code
  // between coro.save and coro.suspend may resume the coroutine and all of the
  // state needs to be saved by that time.
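  // For illustration (hypothetical IR, not from this file):
  //
  //   %x = call i32 @compute()
  //   %save = call token @llvm.coro.save(i8* %hdl)
  //   call void @publish_handle(i8* %hdl)   ; may resume the coroutine elsewhere
  //   %s = call i8 @llvm.coro.suspend(token %save, i1 false)
  //
  // Here %x must already be in the frame before the call that follows
  // coro.save, which is why the coro.save block is treated like a suspend
  // block below.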
  auto markSuspendBlock = [&](IntrinsicInst *BarrierInst) {
    BasicBlock *SuspendBlock = BarrierInst->getParent();
    auto &B = getBlockData(SuspendBlock);
    B.Suspend = true;
    B.Kills |= B.Consumes;
  };
  for (CoroSuspendInst *CSI : Shape.CoroSuspends) {
    markSuspendBlock(CSI);
    markSuspendBlock(CSI->getCoroSave());
  }

  // Iterate propagating consumes and kills until they stop changing.
  int Iteration = 0;
  (void)Iteration;

  bool Changed;
  do {
    LLVM_DEBUG(dbgs() << "iteration " << ++Iteration);
    LLVM_DEBUG(dbgs() << "==============\n");

    Changed = false;
    for (size_t I = 0; I < N; ++I) {
      auto &B = Block[I];
      for (BasicBlock *SI : successors(B)) {
        auto SuccNo = Mapping.blockToIndex(SI);

        // Save the Consumes and Kills bitsets so that it is easy to see
        // if anything changed after propagation.
        auto &S = Block[SuccNo];
        auto SavedConsumes = S.Consumes;
        auto SavedKills = S.Kills;

        // Propagate Kills and Consumes from block B into its successor S.
        S.Consumes |= B.Consumes;
        S.Kills |= B.Kills;

        // If block B is a suspend block, it should propagate kills into its
        // successor for every block B consumes.
        if (B.Suspend) {
          S.Kills |= B.Consumes;
        }
        if (S.Suspend) {
          // If block S is a suspend block, it should kill all of the blocks it
          // consumes.
          S.Kills |= S.Consumes;
        } else if (S.End) {
          // If block S is an end block, it should not propagate kills as the
          // blocks following coro.end() are reached during the initial
          // invocation of the coroutine while all the data are still available
          // on the stack or in the registers.
          S.Kills.reset();
        } else {
          // This is reached when block S is neither a suspend block nor a
          // coro.end block; we need to make sure that it is not in its own
          // kill set.
          S.Kills.reset(SuccNo);
        }

        // See if anything changed.
        Changed |= (S.Kills != SavedKills) || (S.Consumes != SavedConsumes);

        if (S.Kills != SavedKills) {
          LLVM_DEBUG(dbgs() << "\nblock " << I << " follower " << SI->getName()
                            << "\n");
          LLVM_DEBUG(dump("S.Kills", S.Kills));
          LLVM_DEBUG(dump("SavedKills", SavedKills));
        }
        if (S.Consumes != SavedConsumes) {
          LLVM_DEBUG(dbgs() << "\nblock " << I << " follower " << SI->getName()
                            << "\n");
          LLVM_DEBUG(dump("S.Consume", S.Consumes));
          LLVM_DEBUG(dump("SavedCons", SavedConsumes));
        }
      }
    }
  } while (Changed);
  LLVM_DEBUG(dump());
}
#undef DEBUG_TYPE // "coro-suspend-crossing"
#define DEBUG_TYPE "coro-frame"

// We build up the list of spills for every case where a use is separated
// from the definition by a suspend point.
namespace {
class Spill {
  Value *Def = nullptr;
  Instruction *User = nullptr;
  unsigned FieldNo = 0;

public:
  Spill(Value *Def, llvm::User *U) : Def(Def), User(cast<Instruction>(U)) {}

  Value *def() const { return Def; }
  Instruction *user() const { return User; }
  BasicBlock *userBlock() const { return User->getParent(); }

  // Note that the field index is stored in the first SpillEntry for a
  // particular definition. Subsequent mentions of a definition do not have
  // fieldNo assigned. This works out fine as the users of Spills capture the
  // info about the definition the first time they encounter it. Consider
  // refactoring SpillInfo into two arrays to normalize the spill
  // representation.
  unsigned fieldIndex() const {
    assert(FieldNo && "Accessing unassigned field");
    return FieldNo;
  }
  void setFieldIndex(unsigned FieldNumber) {
    assert(!FieldNo && "Reassigning field number");
    FieldNo = FieldNumber;
  }
};
} // namespace
// Note that there may be more than one record with the same value of Def in
// the SpillInfo vector.
using SpillInfo = SmallVector<Spill, 8>;

#ifndef NDEBUG
static void dump(StringRef Title, SpillInfo const &Spills) {
  dbgs() << "------------- " << Title << "--------------\n";
  Value *CurrentValue = nullptr;
  for (auto const &E : Spills) {
    if (CurrentValue != E.def()) {
      CurrentValue = E.def();
      CurrentValue->dump();
    }
    dbgs() << " user: ";
    E.user()->dump();
  }
}
#endif
namespace {
// We cannot rely solely on natural alignment of a type when building a
// coroutine frame; if the alignment specified on the Alloca instruction
// differs from the natural alignment of the alloca type, we will need to
// insert padding.
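//
// For example (illustrative numbers): if the struct built so far is 4 bytes
// and the next field is an i32 alloca annotated with 'align 8', its natural
// offset would be 4 but the forced offset is 8, so a [4 x i8] padding field
// is inserted first.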
struct PaddingCalculator {
  const DataLayout &DL;
  LLVMContext &Context;
  unsigned StructSize = 0;

  PaddingCalculator(LLVMContext &Context, DataLayout const &DL)
      : DL(DL), Context(Context) {}

  // Replicate the logic from IR/DataLayout.cpp to match field offset
  // computation for LLVM structs.
  void addType(Type *Ty) {
    unsigned TyAlign = DL.getABITypeAlignment(Ty);
    if ((StructSize & (TyAlign - 1)) != 0)
      StructSize = alignTo(StructSize, TyAlign);

    StructSize += DL.getTypeAllocSize(Ty); // Consume space for this data item.
  }

  void addTypes(SmallVectorImpl<Type *> const &Types) {
    for (auto *Ty : Types)
      addType(Ty);
  }

  unsigned computePadding(Type *Ty, unsigned ForcedAlignment) {
    unsigned TyAlign = DL.getABITypeAlignment(Ty);
    auto Natural = alignTo(StructSize, TyAlign);
    auto Forced = alignTo(StructSize, ForcedAlignment);

    // Return how many bytes of padding we need to insert.
    if (Natural != Forced)
      return std::max(Natural, Forced) - StructSize;

    // Rely on natural alignment.
    return 0;
  }

  // If padding is required, return the padding field type to insert.
  ArrayType *getPaddingType(Type *Ty, unsigned ForcedAlignment) {
    if (auto Padding = computePadding(Ty, ForcedAlignment))
      return ArrayType::get(Type::getInt8Ty(Context), Padding);

    return nullptr;
  }
};
} // namespace
// Build a struct that will keep state for an active coroutine.
//   struct f.frame {
//     ResumeFnTy ResumeFnAddr;
//     ResumeFnTy DestroyFnAddr;
//     int ResumeIndex;
//     ... promise (if present) ...
//     ... spills ...
//   };
static StructType *buildFrameType(Function &F, coro::Shape &Shape,
                                  SpillInfo &Spills) {
  LLVMContext &C = F.getContext();
  const DataLayout &DL = F.getParent()->getDataLayout();
  PaddingCalculator Padder(C, DL);
  SmallString<32> Name(F.getName());
  Name.append(".Frame");
  StructType *FrameTy = StructType::create(C, Name);
  auto *FramePtrTy = FrameTy->getPointerTo();
  auto *FnTy = FunctionType::get(Type::getVoidTy(C), FramePtrTy,
                                 /*IsVarArgs=*/false);
  auto *FnPtrTy = FnTy->getPointerTo();

  // Figure out how wide an integer type storing the suspend index should be.
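  // For example (illustrative), a coroutine with four suspend points gets a
  // 2-bit index type (Log2_64_Ceil(4) == 2), while a single-suspend coroutine
  // still gets at least an i1 so the field is never zero-width.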
  unsigned IndexBits = std::max(1U, Log2_64_Ceil(Shape.CoroSuspends.size()));
  Type *PromiseType = Shape.PromiseAlloca
                          ? Shape.PromiseAlloca->getType()->getElementType()
                          : Type::getInt1Ty(C);
  SmallVector<Type *, 8> Types{FnPtrTy, FnPtrTy, PromiseType,
                               Type::getIntNTy(C, IndexBits)};
  Value *CurrentDef = nullptr;

  Padder.addTypes(Types);

  // Create an entry for every spilled value.
  for (auto &S : Spills) {
    if (CurrentDef == S.def())
      continue;

    CurrentDef = S.def();
    // PromiseAlloca was already added to Types array earlier.
    if (CurrentDef == Shape.PromiseAlloca)
      continue;

    uint64_t Count = 1;
    Type *Ty = nullptr;
    if (auto *AI = dyn_cast<AllocaInst>(CurrentDef)) {
      Ty = AI->getAllocatedType();
      if (unsigned AllocaAlignment = AI->getAlignment()) {
        // If alignment is specified in alloca, see if we need to insert extra
        // padding.
        if (auto PaddingTy = Padder.getPaddingType(Ty, AllocaAlignment)) {
          Types.push_back(PaddingTy);
          Padder.addType(PaddingTy);
        }
      }
      if (auto *CI = dyn_cast<ConstantInt>(AI->getArraySize()))
        Count = CI->getValue().getZExtValue();
      else
        report_fatal_error("Coroutines cannot handle non static allocas yet");
    } else {
      Ty = CurrentDef->getType();
    }
    S.setFieldIndex(Types.size());
    if (Count == 1)
      Types.push_back(Ty);
    else
      Types.push_back(ArrayType::get(Ty, Count));
    Padder.addType(Ty);
  }
  FrameTy->setBody(Types);

  return FrameTy;
}
// We need to make room to insert a spill after the initial PHIs, but before a
// catchswitch instruction. Placing the spill before the catchswitch would
// violate the requirement that a catchswitch, like all other EHPads, must be
// the first non-PHI in a block.
//
// Split away the catchswitch into a separate block and insert in its place:
//
//   cleanuppad <InsertPt> cleanupret.
//
// The cleanupret instruction will act as an insert point for the spill.
static Instruction *splitBeforeCatchSwitch(CatchSwitchInst *CatchSwitch) {
  BasicBlock *CurrentBlock = CatchSwitch->getParent();
  BasicBlock *NewBlock = CurrentBlock->splitBasicBlock(CatchSwitch);
  CurrentBlock->getTerminator()->eraseFromParent();

  auto *CleanupPad =
      CleanupPadInst::Create(CatchSwitch->getParentPad(), {}, "", CurrentBlock);
  auto *CleanupRet =
      CleanupReturnInst::Create(CleanupPad, NewBlock, CurrentBlock);
  return CleanupRet;
}
// Replace all alloca and SSA values that are accessed across suspend points
// with GetElementPointer from coroutine frame + loads and stores. Create an
// AllocaSpillBB that will become the new entry block for the resume parts of
// the coroutine:
//
//    %hdl = coro.begin(...)
//    whatever
//
// becomes:
//
//    %hdl = coro.begin(...)
//    %FramePtr = bitcast i8* hdl to %f.frame*
//    br label %AllocaSpillBB
//
//  AllocaSpillBB:
//    ; geps corresponding to allocas that were moved to coroutine frame
//    br label %PostSpill
//
//  PostSpill:
//    whatever
//
//
static Instruction *insertSpills(SpillInfo &Spills, coro::Shape &Shape) {
  auto *CB = Shape.CoroBegin;
  LLVMContext &C = CB->getContext();
  IRBuilder<> Builder(CB->getNextNode());
  StructType *FrameTy = Shape.FrameTy;
  PointerType *FramePtrTy = FrameTy->getPointerTo();
  auto *FramePtr =
      cast<Instruction>(Builder.CreateBitCast(CB, FramePtrTy, "FramePtr"));

  Value *CurrentValue = nullptr;
  BasicBlock *CurrentBlock = nullptr;
  Value *CurrentReload = nullptr;
  unsigned Index = 0; // Proper field number will be read from field definition.

  // We need to keep track of any allocas that need "spilling", since they will
  // live in the coroutine frame now; all accesses to them need to be changed,
  // not just the accesses across suspend points. We remember the allocas and
  // their indices to be handled once we have processed all the spills.
  SmallVector<std::pair<AllocaInst *, unsigned>, 4> Allocas;
  // Promise alloca (if present) has a fixed field number (Shape::PromiseField).
  if (Shape.PromiseAlloca)
    Allocas.emplace_back(Shape.PromiseAlloca, coro::Shape::PromiseField);

  // Create a GEP with the given index into the coroutine frame for the original
  // value Orig. Appends an extra 0 index for array-allocas, preserving the
  // original type.
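  // For example (illustrative): for a spilled "%a = alloca i32, i32 4" stored
  // in the frame as a [4 x i32] field, the GEP uses indices {0, FieldNo, 0} so
  // that the result is an i32*, matching the type of the original alloca.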
  auto GetFramePointer = [&](uint32_t Index, Value *Orig) -> Value * {
    SmallVector<Value *, 3> Indices = {
        ConstantInt::get(Type::getInt32Ty(C), 0),
        ConstantInt::get(Type::getInt32Ty(C), Index),
    };

    if (auto *AI = dyn_cast<AllocaInst>(Orig)) {
      if (auto *CI = dyn_cast<ConstantInt>(AI->getArraySize())) {
        auto Count = CI->getValue().getZExtValue();
        if (Count > 1) {
          Indices.push_back(ConstantInt::get(Type::getInt32Ty(C), 0));
        }
      } else {
        report_fatal_error("Coroutines cannot handle non static allocas yet");
      }
    }

    return Builder.CreateInBoundsGEP(FrameTy, FramePtr, Indices);
  };

  // Create a load instruction to reload the spilled value from the coroutine
  // frame.
  auto CreateReload = [&](Instruction *InsertBefore) {
    assert(Index && "accessing unassigned field number");
    Builder.SetInsertPoint(InsertBefore);

    auto *G = GetFramePointer(Index, CurrentValue);
    G->setName(CurrentValue->getName() + Twine(".reload.addr"));

    return isa<AllocaInst>(CurrentValue)
               ? G
               : Builder.CreateLoad(FrameTy->getElementType(Index), G,
                                    CurrentValue->getName() + Twine(".reload"));
  };
  for (auto const &E : Spills) {
    // If we have not seen the value, generate a spill.
    if (CurrentValue != E.def()) {
      CurrentValue = E.def();
      CurrentBlock = nullptr;
      CurrentReload = nullptr;

      Index = E.fieldIndex();

      if (auto *AI = dyn_cast<AllocaInst>(CurrentValue)) {
        // A spilled AllocaInst will be replaced with a GEP from the coroutine
        // frame; there is no spill required.
        Allocas.emplace_back(AI, Index);
        if (!AI->isStaticAlloca())
          report_fatal_error("Coroutines cannot handle non static allocas yet");
      } else {
        // Otherwise, create a store instruction storing the value into the
        // coroutine frame.
        Instruction *InsertPt = nullptr;
        if (isa<Argument>(CurrentValue)) {
          // For arguments, we will place the store instruction right after
          // the coroutine frame pointer instruction, i.e. bitcast of
          // coro.begin from i8* to %f.frame*.
          InsertPt = FramePtr->getNextNode();
        } else if (auto *II = dyn_cast<InvokeInst>(CurrentValue)) {
          // If we are spilling the result of the invoke instruction, split the
          // normal edge and insert the spill in the new block.
          auto NewBB = SplitEdge(II->getParent(), II->getNormalDest());
          InsertPt = NewBB->getTerminator();
        } else if (isa<PHINode>(CurrentValue)) {
          // Skip the PHINodes and EH pad instructions.
          BasicBlock *DefBlock = cast<Instruction>(E.def())->getParent();
          if (auto *CSI = dyn_cast<CatchSwitchInst>(DefBlock->getTerminator()))
            InsertPt = splitBeforeCatchSwitch(CSI);
          else
            InsertPt = &*DefBlock->getFirstInsertionPt();
        } else {
          // For all other values, the spill is placed immediately after
          // the definition.
          assert(!cast<Instruction>(E.def())->isTerminator() &&
                 "unexpected terminator");
          InsertPt = cast<Instruction>(E.def())->getNextNode();
        }

        Builder.SetInsertPoint(InsertPt);
        auto *G = Builder.CreateConstInBoundsGEP2_32(
            FrameTy, FramePtr, 0, Index,
            CurrentValue->getName() + Twine(".spill.addr"));
        Builder.CreateStore(CurrentValue, G);
      }
    }

    // If we have not seen the use block, generate a reload in it.
    if (CurrentBlock != E.userBlock()) {
      CurrentBlock = E.userBlock();
      CurrentReload = CreateReload(&*CurrentBlock->getFirstInsertionPt());
    }

    // If we have a single edge PHINode, remove it and replace it with a reload
    // from the coroutine frame. (We already took care of multi edge PHINodes
    // by rewriting them in the rewritePHIs function).
    if (auto *PN = dyn_cast<PHINode>(E.user())) {
      assert(PN->getNumIncomingValues() == 1 && "unexpected number of incoming "
                                                "values in the PHINode");
      PN->replaceAllUsesWith(CurrentReload);
      PN->eraseFromParent();
      continue;
    }

    // Replace all uses of CurrentValue in the current instruction with reload.
    E.user()->replaceUsesOfWith(CurrentValue, CurrentReload);
  }

  BasicBlock *FramePtrBB = FramePtr->getParent();
  Shape.AllocaSpillBlock =
      FramePtrBB->splitBasicBlock(FramePtr->getNextNode(), "AllocaSpillBB");
  Shape.AllocaSpillBlock->splitBasicBlock(&Shape.AllocaSpillBlock->front(),
                                          "PostSpill");

  Builder.SetInsertPoint(&Shape.AllocaSpillBlock->front());
  // If we found any allocas, replace all of their remaining uses with GEPs.
  for (auto &P : Allocas) {
    auto *G = GetFramePointer(P.second, P.first);

    // We are not using ReplaceInstWithInst(P.first, cast<Instruction>(G)) here,
    // as we are changing the location of the instruction.
    G->takeName(P.first);
    P.first->replaceAllUsesWith(G);
    P.first->eraseFromParent();
  }
  return FramePtr;
}
// Sets the unwind edge of an instruction to a particular successor.
static void setUnwindEdgeTo(Instruction *TI, BasicBlock *Succ) {
  if (auto *II = dyn_cast<InvokeInst>(TI))
    II->setUnwindDest(Succ);
  else if (auto *CS = dyn_cast<CatchSwitchInst>(TI))
    CS->setUnwindDest(Succ);
  else if (auto *CR = dyn_cast<CleanupReturnInst>(TI))
    CR->setUnwindDest(Succ);
  else
    llvm_unreachable("unexpected terminator instruction");
}

// Replaces all uses of OldPred with the NewPred block in all PHINodes in a
// block.
static void updatePhiNodes(BasicBlock *DestBB, BasicBlock *OldPred,
                           BasicBlock *NewPred,
                           PHINode *LandingPadReplacement) {
  unsigned BBIdx = 0;
  for (BasicBlock::iterator I = DestBB->begin(); isa<PHINode>(I); ++I) {
    PHINode *PN = cast<PHINode>(I);

    // We manually update the LandingPadReplacement PHINode and it is the last
    // PHI Node. So, if we find it, we are done.
    if (LandingPadReplacement == PN)
      break;

    // Reuse the previous value of BBIdx if it lines up. In cases where we
    // have multiple phi nodes with *lots* of predecessors, this is a speed
    // win because we don't have to scan the PHI looking for TIBB. This
    // happens because the incoming block list of PHI nodes is usually in the
    // same order.
    if (PN->getIncomingBlock(BBIdx) != OldPred)
      BBIdx = PN->getBasicBlockIndex(OldPred);
    assert(BBIdx != (unsigned)-1 && "Invalid PHI Index!");
    PN->setIncomingBlock(BBIdx, NewPred);
  }
}
// Uses SplitEdge unless the successor block is an EHPad, in which case do EH
// specific handling.
static BasicBlock *ehAwareSplitEdge(BasicBlock *BB, BasicBlock *Succ,
                                    LandingPadInst *OriginalPad,
                                    PHINode *LandingPadReplacement) {
  auto *PadInst = Succ->getFirstNonPHI();
  if (!LandingPadReplacement && !PadInst->isEHPad())
    return SplitEdge(BB, Succ);

  auto *NewBB = BasicBlock::Create(BB->getContext(), "", BB->getParent(), Succ);
  setUnwindEdgeTo(BB->getTerminator(), NewBB);
  updatePhiNodes(Succ, BB, NewBB, LandingPadReplacement);

  if (LandingPadReplacement) {
    auto *NewLP = OriginalPad->clone();
    auto *Terminator = BranchInst::Create(Succ, NewBB);
    NewLP->insertBefore(Terminator);
    LandingPadReplacement->addIncoming(NewLP, NewBB);
    return NewBB;
  }
  Value *ParentPad = nullptr;
  if (auto *FuncletPad = dyn_cast<FuncletPadInst>(PadInst))
    ParentPad = FuncletPad->getParentPad();
  else if (auto *CatchSwitch = dyn_cast<CatchSwitchInst>(PadInst))
    ParentPad = CatchSwitch->getParentPad();
  else
    llvm_unreachable("handling for other EHPads not implemented yet");

  auto *NewCleanupPad = CleanupPadInst::Create(ParentPad, {}, "", NewBB);
  CleanupReturnInst::Create(NewCleanupPad, Succ, NewBB);
  return NewBB;
}
static void rewritePHIs(BasicBlock &BB) {
  // For every incoming edge we will create a block holding all
  // incoming values in a single PHI node.
  //
  // loop:
  //    %n.val = phi i32[%n, %entry], [%inc, %loop]
  //
  // It will create:
  //
  // loop.from.entry:
  //    %n.loop.pre = phi i32 [%n, %entry]
  //    br label %loop
  // loop.from.loop:
  //    %inc.loop.pre = phi i32 [%inc, %loop]
  //    br label %loop
  //
  // After this rewrite, further analysis will ignore any phi nodes with more
  // than one incoming edge.

  // TODO: Simplify PHINodes in the basic block to remove duplicate
  // predecessors.

  LandingPadInst *LandingPad = nullptr;
  PHINode *ReplPHI = nullptr;
  if ((LandingPad = dyn_cast_or_null<LandingPadInst>(BB.getFirstNonPHI()))) {
    // ehAwareSplitEdge will clone the LandingPad in all the edge blocks.
    // We replace the original landing pad with a PHINode that will collect the
    // results from all of them.
    ReplPHI = PHINode::Create(LandingPad->getType(), 1, "", LandingPad);
    ReplPHI->takeName(LandingPad);
    LandingPad->replaceAllUsesWith(ReplPHI);
    // We will erase the original landing pad at the end of this function after
    // ehAwareSplitEdge cloned it in the transition blocks.
  }

  SmallVector<BasicBlock *, 8> Preds(pred_begin(&BB), pred_end(&BB));
  for (BasicBlock *Pred : Preds) {
    auto *IncomingBB = ehAwareSplitEdge(Pred, &BB, LandingPad, ReplPHI);
    IncomingBB->setName(BB.getName() + Twine(".from.") + Pred->getName());
    auto *PN = cast<PHINode>(&BB.front());
    do {
      int Index = PN->getBasicBlockIndex(IncomingBB);
      Value *V = PN->getIncomingValue(Index);
      PHINode *InputV = PHINode::Create(
          V->getType(), 1, V->getName() + Twine(".") + BB.getName(),
          &IncomingBB->front());
      InputV->addIncoming(V, Pred);
      PN->setIncomingValue(Index, InputV);
      PN = dyn_cast<PHINode>(PN->getNextNode());
    } while (PN != ReplPHI); // ReplPHI is either null or the PHI that replaced
                             // the landing pad.
  }

  if (LandingPad) {
    // Calls to the ehAwareSplitEdge function cloned the original landing pad;
    // we no longer need it.
    LandingPad->eraseFromParent();
  }
}
static void rewritePHIs(Function &F) {
  SmallVector<BasicBlock *, 8> WorkList;

  for (BasicBlock &BB : F)
    if (auto *PN = dyn_cast<PHINode>(&BB.front()))
      if (PN->getNumIncomingValues() > 1)
        WorkList.push_back(&BB);

  for (BasicBlock *BB : WorkList)
    rewritePHIs(*BB);
}

// Check for instructions that we can recreate on resume as opposed to spilling
// the result into a coroutine frame.
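// For example (illustrative), a getelementptr or zext computed before a
// suspend point and used after it can simply be recomputed at the use site
// from values that are already in the frame, instead of occupying a frame
// slot of its own.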
static bool materializable(Instruction &V) {
  return isa<CastInst>(&V) || isa<GetElementPtrInst>(&V) ||
         isa<BinaryOperator>(&V) || isa<CmpInst>(&V) || isa<SelectInst>(&V);
}

// Check for structural coroutine intrinsics that should not be spilled into
// the coroutine frame.
static bool isCoroutineStructureIntrinsic(Instruction &I) {
  return isa<CoroIdInst>(&I) || isa<CoroSaveInst>(&I) ||
         isa<CoroSuspendInst>(&I);
}
// For every use of a value that is live across a suspend point, recreate that
// value after the suspend point.
static void rewriteMaterializableInstructions(IRBuilder<> &IRB,
                                              SpillInfo const &Spills) {
  BasicBlock *CurrentBlock = nullptr;
  Instruction *CurrentMaterialization = nullptr;
  Instruction *CurrentDef = nullptr;

  for (auto const &E : Spills) {
    // If it is a new definition, update CurrentXXX variables.
    if (CurrentDef != E.def()) {
      CurrentDef = cast<Instruction>(E.def());
      CurrentBlock = nullptr;
      CurrentMaterialization = nullptr;
    }

    // If we have not seen this block, materialize the value.
    if (CurrentBlock != E.userBlock()) {
      CurrentBlock = E.userBlock();
      CurrentMaterialization = cast<Instruction>(CurrentDef)->clone();
      CurrentMaterialization->setName(CurrentDef->getName());
      CurrentMaterialization->insertBefore(
          &*CurrentBlock->getFirstInsertionPt());
    }

    if (auto *PN = dyn_cast<PHINode>(E.user())) {
      assert(PN->getNumIncomingValues() == 1 && "unexpected number of incoming "
                                                "values in the PHINode");
      PN->replaceAllUsesWith(CurrentMaterialization);
      PN->eraseFromParent();
      continue;
    }

    // Replace all uses of CurrentDef in the current instruction with the
    // CurrentMaterialization for the block.
    E.user()->replaceUsesOfWith(CurrentDef, CurrentMaterialization);
  }
}
// Move early uses of spilled variables after CoroBegin.
// For example, if a parameter had its address taken, we may end up with code
// like this:
//   define @f(i32 %n) {
//     %n.addr = alloca i32
//     store %n, %n.addr
//     ...
//     call @coro.begin
// We need to move the store after coro.begin.
static void moveSpillUsesAfterCoroBegin(Function &F, SpillInfo const &Spills,
                                        CoroBeginInst *CoroBegin) {
  DominatorTree DT(F);
  SmallVector<Instruction *, 8> NeedsMoving;

  Value *CurrentValue = nullptr;

  for (auto const &E : Spills) {
    if (CurrentValue == E.def())
      continue;

    CurrentValue = E.def();

    for (User *U : CurrentValue->users()) {
      Instruction *I = cast<Instruction>(U);
      if (!DT.dominates(CoroBegin, I)) {
        LLVM_DEBUG(dbgs() << "will move: " << *I << "\n");

        // TODO: Make this more robust. Currently, if we run into a situation
        // where a simple instruction move won't work, we panic and call
        // report_fatal_error.
        for (User *UI : I->users()) {
          if (!DT.dominates(CoroBegin, cast<Instruction>(UI)))
            report_fatal_error("cannot move instruction since its users are not"
                               " dominated by CoroBegin");
        }

        NeedsMoving.push_back(I);
      }
    }
  }

  Instruction *InsertPt = CoroBegin->getNextNode();
  for (Instruction *I : NeedsMoving)
    I->moveBefore(InsertPt);
}
// Splits the block at a particular instruction unless it is the first
// instruction in the block with a single predecessor.
static BasicBlock *splitBlockIfNotFirst(Instruction *I, const Twine &Name) {
  auto *BB = I->getParent();
  if (&BB->front() == I) {
    if (BB->getSinglePredecessor()) {
      BB->setName(Name);
      return BB;
    }
  }
  return BB->splitBasicBlock(I, Name);
}

// Split above and below a particular instruction so that it will be all alone
// in its own block.
static void splitAround(Instruction *I, const Twine &Name) {
  splitBlockIfNotFirst(I, Name);
  splitBlockIfNotFirst(I->getNextNode(), "After" + Name);
}
void coro::buildCoroutineFrame(Function &F, Shape &Shape) {
  // Lower coro.dbg.declare to coro.dbg.value, since we are going to rewrite
  // access to local variables.
  LowerDbgDeclare(F);

  Shape.PromiseAlloca = Shape.CoroBegin->getId()->getPromise();
  if (Shape.PromiseAlloca) {
    Shape.CoroBegin->getId()->clearPromise();
  }

  // Make sure that all coro.save, coro.suspend and the fallthrough coro.end
  // intrinsics are in their own blocks to simplify the logic of building up
  // SuspendCrossing data.
  for (CoroSuspendInst *CSI : Shape.CoroSuspends) {
    splitAround(CSI->getCoroSave(), "CoroSave");
    splitAround(CSI, "CoroSuspend");
  }

  // Put CoroEnds into their own blocks.
  for (CoroEndInst *CE : Shape.CoroEnds)
    splitAround(CE, "CoroEnd");

  // Transform multi-edge PHI nodes so that any value feeding into a PHI never
  // has its definition separated from the PHI by a suspend point.
  rewritePHIs(F);

  // Build suspend crossing info.
  SuspendCrossingInfo Checker(F, Shape);

  IRBuilder<> Builder(F.getContext());
  SpillInfo Spills;

  for (int Repeat = 0; Repeat < 4; ++Repeat) {
    // See if there are materializable instructions across suspend points.
    for (Instruction &I : instructions(F))
      if (materializable(I))
        for (User *U : I.users())
          if (Checker.isDefinitionAcrossSuspend(I, U))
            Spills.emplace_back(&I, U);

    if (Spills.empty())
      break;

    // Rewrite materializable instructions to be materialized at the use point.
    LLVM_DEBUG(dump("Materializations", Spills));
    rewriteMaterializableInstructions(Builder, Spills);
    Spills.clear();
  }

  // Collect the spills for arguments and other not-materializable values.
  for (Argument &A : F.args())
    for (User *U : A.users())
      if (Checker.isDefinitionAcrossSuspend(A, U))
        Spills.emplace_back(&A, U);

  for (Instruction &I : instructions(F)) {
    // Values returned from coroutine structure intrinsics should not be part
    // of the Coroutine Frame.
    if (isCoroutineStructureIntrinsic(I) || &I == Shape.CoroBegin)
      continue;
    // The coroutine promise is always included in the coroutine frame; there
    // is no need to check for suspend crossing.
    if (Shape.PromiseAlloca == &I)
      continue;

    for (User *U : I.users())
      if (Checker.isDefinitionAcrossSuspend(I, U)) {
        // We cannot spill a token.
        if (I.getType()->isTokenTy())
          report_fatal_error(
              "token definition is separated from the use by a suspend point");
        Spills.emplace_back(&I, U);
      }
  }
  LLVM_DEBUG(dump("Spills", Spills));
  moveSpillUsesAfterCoroBegin(F, Spills, Shape.CoroBegin);
  Shape.FrameTy = buildFrameType(F, Shape, Spills);
  Shape.FramePtr = insertSpills(Spills, Shape);
}