CoroSplit.cpp 34 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959
  1. //===- CoroSplit.cpp - Converts a coroutine into a state machine ----------===//
  2. //
  3. // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
  4. // See https://llvm.org/LICENSE.txt for license information.
  5. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
  6. //
  7. //===----------------------------------------------------------------------===//
  8. // This pass builds the coroutine frame and outlines resume and destroy parts
  9. // of the coroutine into separate functions.
  10. //
  11. // We present a coroutine to an LLVM as an ordinary function with suspension
  12. // points marked up with intrinsics. We let the optimizer party on the coroutine
  13. // as a single function for as long as possible. Shortly before the coroutine is
  14. // eligible to be inlined into its callers, we split up the coroutine into parts
  15. // corresponding to an initial, resume and destroy invocations of the coroutine,
  16. // add them to the current SCC and restart the IPO pipeline to optimize the
  17. // coroutine subfunctions we extracted before proceeding to the caller of the
  18. // coroutine.
  19. //===----------------------------------------------------------------------===//
  20. #include "CoroInstr.h"
  21. #include "CoroInternal.h"
  22. #include "llvm/ADT/DenseMap.h"
  23. #include "llvm/ADT/SmallPtrSet.h"
  24. #include "llvm/ADT/SmallVector.h"
  25. #include "llvm/ADT/StringRef.h"
  26. #include "llvm/ADT/Twine.h"
  27. #include "llvm/Analysis/CallGraph.h"
  28. #include "llvm/Analysis/CallGraphSCCPass.h"
  29. #include "llvm/Transforms/Utils/Local.h"
  30. #include "llvm/IR/Argument.h"
  31. #include "llvm/IR/Attributes.h"
  32. #include "llvm/IR/BasicBlock.h"
  33. #include "llvm/IR/CFG.h"
  34. #include "llvm/IR/CallSite.h"
  35. #include "llvm/IR/CallingConv.h"
  36. #include "llvm/IR/Constants.h"
  37. #include "llvm/IR/DataLayout.h"
  38. #include "llvm/IR/DerivedTypes.h"
  39. #include "llvm/IR/Function.h"
  40. #include "llvm/IR/GlobalValue.h"
  41. #include "llvm/IR/GlobalVariable.h"
  42. #include "llvm/IR/IRBuilder.h"
  43. #include "llvm/IR/InstIterator.h"
  44. #include "llvm/IR/InstrTypes.h"
  45. #include "llvm/IR/Instruction.h"
  46. #include "llvm/IR/Instructions.h"
  47. #include "llvm/IR/IntrinsicInst.h"
  48. #include "llvm/IR/LLVMContext.h"
  49. #include "llvm/IR/LegacyPassManager.h"
  50. #include "llvm/IR/Module.h"
  51. #include "llvm/IR/Type.h"
  52. #include "llvm/IR/Value.h"
  53. #include "llvm/IR/Verifier.h"
  54. #include "llvm/Pass.h"
  55. #include "llvm/Support/Casting.h"
  56. #include "llvm/Support/Debug.h"
  57. #include "llvm/Support/raw_ostream.h"
  58. #include "llvm/Transforms/Scalar.h"
  59. #include "llvm/Transforms/Utils/BasicBlockUtils.h"
  60. #include "llvm/Transforms/Utils/Cloning.h"
  61. #include "llvm/Transforms/Utils/ValueMapper.h"
  62. #include <cassert>
  63. #include <cstddef>
  64. #include <cstdint>
  65. #include <initializer_list>
  66. #include <iterator>
  67. using namespace llvm;
  68. #define DEBUG_TYPE "coro-split"
  69. // Create an entry block for a resume function with a switch that will jump to
  70. // suspend points.
  71. static BasicBlock *createResumeEntryBlock(Function &F, coro::Shape &Shape) {
  72. LLVMContext &C = F.getContext();
  73. // resume.entry:
  74. // %index.addr = getelementptr inbounds %f.Frame, %f.Frame* %FramePtr, i32 0,
  75. // i32 2
  76. // % index = load i32, i32* %index.addr
  77. // switch i32 %index, label %unreachable [
  78. // i32 0, label %resume.0
  79. // i32 1, label %resume.1
  80. // ...
  81. // ]
  82. auto *NewEntry = BasicBlock::Create(C, "resume.entry", &F);
  83. auto *UnreachBB = BasicBlock::Create(C, "unreachable", &F);
  84. IRBuilder<> Builder(NewEntry);
  85. auto *FramePtr = Shape.FramePtr;
  86. auto *FrameTy = Shape.FrameTy;
  87. auto *GepIndex = Builder.CreateConstInBoundsGEP2_32(
  88. FrameTy, FramePtr, 0, coro::Shape::IndexField, "index.addr");
  89. auto *Index = Builder.CreateLoad(Shape.getIndexType(), GepIndex, "index");
  90. auto *Switch =
  91. Builder.CreateSwitch(Index, UnreachBB, Shape.CoroSuspends.size());
  92. Shape.ResumeSwitch = Switch;
  93. size_t SuspendIndex = 0;
  94. for (CoroSuspendInst *S : Shape.CoroSuspends) {
  95. ConstantInt *IndexVal = Shape.getIndex(SuspendIndex);
  96. // Replace CoroSave with a store to Index:
  97. // %index.addr = getelementptr %f.frame... (index field number)
  98. // store i32 0, i32* %index.addr1
  99. auto *Save = S->getCoroSave();
  100. Builder.SetInsertPoint(Save);
  101. if (S->isFinal()) {
  102. // Final suspend point is represented by storing zero in ResumeFnAddr.
  103. auto *GepIndex = Builder.CreateConstInBoundsGEP2_32(FrameTy, FramePtr, 0,
  104. 0, "ResumeFn.addr");
  105. auto *NullPtr = ConstantPointerNull::get(cast<PointerType>(
  106. cast<PointerType>(GepIndex->getType())->getElementType()));
  107. Builder.CreateStore(NullPtr, GepIndex);
  108. } else {
  109. auto *GepIndex = Builder.CreateConstInBoundsGEP2_32(
  110. FrameTy, FramePtr, 0, coro::Shape::IndexField, "index.addr");
  111. Builder.CreateStore(IndexVal, GepIndex);
  112. }
  113. Save->replaceAllUsesWith(ConstantTokenNone::get(C));
  114. Save->eraseFromParent();
  115. // Split block before and after coro.suspend and add a jump from an entry
  116. // switch:
  117. //
  118. // whateverBB:
  119. // whatever
  120. // %0 = call i8 @llvm.coro.suspend(token none, i1 false)
  121. // switch i8 %0, label %suspend[i8 0, label %resume
  122. // i8 1, label %cleanup]
  123. // becomes:
  124. //
  125. // whateverBB:
  126. // whatever
  127. // br label %resume.0.landing
  128. //
  129. // resume.0: ; <--- jump from the switch in the resume.entry
  130. // %0 = tail call i8 @llvm.coro.suspend(token none, i1 false)
  131. // br label %resume.0.landing
  132. //
  133. // resume.0.landing:
  134. // %1 = phi i8[-1, %whateverBB], [%0, %resume.0]
  135. // switch i8 % 1, label %suspend [i8 0, label %resume
  136. // i8 1, label %cleanup]
  137. auto *SuspendBB = S->getParent();
  138. auto *ResumeBB =
  139. SuspendBB->splitBasicBlock(S, "resume." + Twine(SuspendIndex));
  140. auto *LandingBB = ResumeBB->splitBasicBlock(
  141. S->getNextNode(), ResumeBB->getName() + Twine(".landing"));
  142. Switch->addCase(IndexVal, ResumeBB);
  143. cast<BranchInst>(SuspendBB->getTerminator())->setSuccessor(0, LandingBB);
  144. auto *PN = PHINode::Create(Builder.getInt8Ty(), 2, "", &LandingBB->front());
  145. S->replaceAllUsesWith(PN);
  146. PN->addIncoming(Builder.getInt8(-1), SuspendBB);
  147. PN->addIncoming(S, ResumeBB);
  148. ++SuspendIndex;
  149. }
  150. Builder.SetInsertPoint(UnreachBB);
  151. Builder.CreateUnreachable();
  152. return NewEntry;
  153. }
  154. // In Resumers, we replace fallthrough coro.end with ret void and delete the
  155. // rest of the block.
  156. static void replaceFallthroughCoroEnd(IntrinsicInst *End,
  157. ValueToValueMapTy &VMap) {
  158. auto *NewE = cast<IntrinsicInst>(VMap[End]);
  159. ReturnInst::Create(NewE->getContext(), nullptr, NewE);
  160. // Remove the rest of the block, by splitting it into an unreachable block.
  161. auto *BB = NewE->getParent();
  162. BB->splitBasicBlock(NewE);
  163. BB->getTerminator()->eraseFromParent();
  164. }
  165. // In Resumers, we replace unwind coro.end with True to force the immediate
  166. // unwind to caller.
  167. static void replaceUnwindCoroEnds(coro::Shape &Shape, ValueToValueMapTy &VMap) {
  168. if (Shape.CoroEnds.empty())
  169. return;
  170. LLVMContext &Context = Shape.CoroEnds.front()->getContext();
  171. auto *True = ConstantInt::getTrue(Context);
  172. for (CoroEndInst *CE : Shape.CoroEnds) {
  173. if (!CE->isUnwind())
  174. continue;
  175. auto *NewCE = cast<IntrinsicInst>(VMap[CE]);
  176. // If coro.end has an associated bundle, add cleanupret instruction.
  177. if (auto Bundle = NewCE->getOperandBundle(LLVMContext::OB_funclet)) {
  178. Value *FromPad = Bundle->Inputs[0];
  179. auto *CleanupRet = CleanupReturnInst::Create(FromPad, nullptr, NewCE);
  180. NewCE->getParent()->splitBasicBlock(NewCE);
  181. CleanupRet->getParent()->getTerminator()->eraseFromParent();
  182. }
  183. NewCE->replaceAllUsesWith(True);
  184. NewCE->eraseFromParent();
  185. }
  186. }
  187. // Rewrite final suspend point handling. We do not use suspend index to
  188. // represent the final suspend point. Instead we zero-out ResumeFnAddr in the
  189. // coroutine frame, since it is undefined behavior to resume a coroutine
  190. // suspended at the final suspend point. Thus, in the resume function, we can
  191. // simply remove the last case (when coro::Shape is built, the final suspend
  192. // point (if present) is always the last element of CoroSuspends array).
  193. // In the destroy function, we add a code sequence to check if ResumeFnAddress
  194. // is Null, and if so, jump to the appropriate label to handle cleanup from the
  195. // final suspend point.
  196. static void handleFinalSuspend(IRBuilder<> &Builder, Value *FramePtr,
  197. coro::Shape &Shape, SwitchInst *Switch,
  198. bool IsDestroy) {
  199. assert(Shape.HasFinalSuspend);
  200. auto FinalCaseIt = std::prev(Switch->case_end());
  201. BasicBlock *ResumeBB = FinalCaseIt->getCaseSuccessor();
  202. Switch->removeCase(FinalCaseIt);
  203. if (IsDestroy) {
  204. BasicBlock *OldSwitchBB = Switch->getParent();
  205. auto *NewSwitchBB = OldSwitchBB->splitBasicBlock(Switch, "Switch");
  206. Builder.SetInsertPoint(OldSwitchBB->getTerminator());
  207. auto *GepIndex = Builder.CreateConstInBoundsGEP2_32(Shape.FrameTy, FramePtr,
  208. 0, 0, "ResumeFn.addr");
  209. auto *Load = Builder.CreateLoad(
  210. Shape.FrameTy->getElementType(coro::Shape::ResumeField), GepIndex);
  211. auto *NullPtr =
  212. ConstantPointerNull::get(cast<PointerType>(Load->getType()));
  213. auto *Cond = Builder.CreateICmpEQ(Load, NullPtr);
  214. Builder.CreateCondBr(Cond, ResumeBB, NewSwitchBB);
  215. OldSwitchBB->getTerminator()->eraseFromParent();
  216. }
  217. }
  218. // Create a resume clone by cloning the body of the original function, setting
  219. // new entry block and replacing coro.suspend an appropriate value to force
  220. // resume or cleanup pass for every suspend point.
  221. static Function *createClone(Function &F, Twine Suffix, coro::Shape &Shape,
  222. BasicBlock *ResumeEntry, int8_t FnIndex) {
  223. Module *M = F.getParent();
  224. auto *FrameTy = Shape.FrameTy;
  225. auto *FnPtrTy = cast<PointerType>(FrameTy->getElementType(0));
  226. auto *FnTy = cast<FunctionType>(FnPtrTy->getElementType());
  227. Function *NewF =
  228. Function::Create(FnTy, GlobalValue::LinkageTypes::ExternalLinkage,
  229. F.getName() + Suffix, M);
  230. NewF->addParamAttr(0, Attribute::NonNull);
  231. NewF->addParamAttr(0, Attribute::NoAlias);
  232. ValueToValueMapTy VMap;
  233. // Replace all args with undefs. The buildCoroutineFrame algorithm already
  234. // rewritten access to the args that occurs after suspend points with loads
  235. // and stores to/from the coroutine frame.
  236. for (Argument &A : F.args())
  237. VMap[&A] = UndefValue::get(A.getType());
  238. SmallVector<ReturnInst *, 4> Returns;
  239. CloneFunctionInto(NewF, &F, VMap, /*ModuleLevelChanges=*/true, Returns);
  240. NewF->setLinkage(GlobalValue::LinkageTypes::InternalLinkage);
  241. // Remove old returns.
  242. for (ReturnInst *Return : Returns)
  243. changeToUnreachable(Return, /*UseLLVMTrap=*/false);
  244. // Remove old return attributes.
  245. NewF->removeAttributes(
  246. AttributeList::ReturnIndex,
  247. AttributeFuncs::typeIncompatible(NewF->getReturnType()));
  248. // Make AllocaSpillBlock the new entry block.
  249. auto *SwitchBB = cast<BasicBlock>(VMap[ResumeEntry]);
  250. auto *Entry = cast<BasicBlock>(VMap[Shape.AllocaSpillBlock]);
  251. Entry->moveBefore(&NewF->getEntryBlock());
  252. Entry->getTerminator()->eraseFromParent();
  253. BranchInst::Create(SwitchBB, Entry);
  254. Entry->setName("entry" + Suffix);
  255. // Clear all predecessors of the new entry block.
  256. auto *Switch = cast<SwitchInst>(VMap[Shape.ResumeSwitch]);
  257. Entry->replaceAllUsesWith(Switch->getDefaultDest());
  258. IRBuilder<> Builder(&NewF->getEntryBlock().front());
  259. // Remap frame pointer.
  260. Argument *NewFramePtr = &*NewF->arg_begin();
  261. Value *OldFramePtr = cast<Value>(VMap[Shape.FramePtr]);
  262. NewFramePtr->takeName(OldFramePtr);
  263. OldFramePtr->replaceAllUsesWith(NewFramePtr);
  264. // Remap vFrame pointer.
  265. auto *NewVFrame = Builder.CreateBitCast(
  266. NewFramePtr, Type::getInt8PtrTy(Builder.getContext()), "vFrame");
  267. Value *OldVFrame = cast<Value>(VMap[Shape.CoroBegin]);
  268. OldVFrame->replaceAllUsesWith(NewVFrame);
  269. // Rewrite final suspend handling as it is not done via switch (allows to
  270. // remove final case from the switch, since it is undefined behavior to resume
  271. // the coroutine suspended at the final suspend point.
  272. if (Shape.HasFinalSuspend) {
  273. auto *Switch = cast<SwitchInst>(VMap[Shape.ResumeSwitch]);
  274. bool IsDestroy = FnIndex != 0;
  275. handleFinalSuspend(Builder, NewFramePtr, Shape, Switch, IsDestroy);
  276. }
  277. // Replace coro suspend with the appropriate resume index.
  278. // Replacing coro.suspend with (0) will result in control flow proceeding to
  279. // a resume label associated with a suspend point, replacing it with (1) will
  280. // result in control flow proceeding to a cleanup label associated with this
  281. // suspend point.
  282. auto *NewValue = Builder.getInt8(FnIndex ? 1 : 0);
  283. for (CoroSuspendInst *CS : Shape.CoroSuspends) {
  284. auto *MappedCS = cast<CoroSuspendInst>(VMap[CS]);
  285. MappedCS->replaceAllUsesWith(NewValue);
  286. MappedCS->eraseFromParent();
  287. }
  288. // Remove coro.end intrinsics.
  289. replaceFallthroughCoroEnd(Shape.CoroEnds.front(), VMap);
  290. replaceUnwindCoroEnds(Shape, VMap);
  291. // Eliminate coro.free from the clones, replacing it with 'null' in cleanup,
  292. // to suppress deallocation code.
  293. coro::replaceCoroFree(cast<CoroIdInst>(VMap[Shape.CoroBegin->getId()]),
  294. /*Elide=*/FnIndex == 2);
  295. NewF->setCallingConv(CallingConv::Fast);
  296. return NewF;
  297. }
  298. static void removeCoroEnds(coro::Shape &Shape) {
  299. if (Shape.CoroEnds.empty())
  300. return;
  301. LLVMContext &Context = Shape.CoroEnds.front()->getContext();
  302. auto *False = ConstantInt::getFalse(Context);
  303. for (CoroEndInst *CE : Shape.CoroEnds) {
  304. CE->replaceAllUsesWith(False);
  305. CE->eraseFromParent();
  306. }
  307. }
  308. static void replaceFrameSize(coro::Shape &Shape) {
  309. if (Shape.CoroSizes.empty())
  310. return;
  311. // In the same function all coro.sizes should have the same result type.
  312. auto *SizeIntrin = Shape.CoroSizes.back();
  313. Module *M = SizeIntrin->getModule();
  314. const DataLayout &DL = M->getDataLayout();
  315. auto Size = DL.getTypeAllocSize(Shape.FrameTy);
  316. auto *SizeConstant = ConstantInt::get(SizeIntrin->getType(), Size);
  317. for (CoroSizeInst *CS : Shape.CoroSizes) {
  318. CS->replaceAllUsesWith(SizeConstant);
  319. CS->eraseFromParent();
  320. }
  321. }
  322. // Create a global constant array containing pointers to functions provided and
  323. // set Info parameter of CoroBegin to point at this constant. Example:
  324. //
  325. // @f.resumers = internal constant [2 x void(%f.frame*)*]
  326. // [void(%f.frame*)* @f.resume, void(%f.frame*)* @f.destroy]
  327. // define void @f() {
  328. // ...
  329. // call i8* @llvm.coro.begin(i8* null, i32 0, i8* null,
  330. // i8* bitcast([2 x void(%f.frame*)*] * @f.resumers to i8*))
  331. //
  332. // Assumes that all the functions have the same signature.
  333. static void setCoroInfo(Function &F, CoroBeginInst *CoroBegin,
  334. std::initializer_list<Function *> Fns) {
  335. SmallVector<Constant *, 4> Args(Fns.begin(), Fns.end());
  336. assert(!Args.empty());
  337. Function *Part = *Fns.begin();
  338. Module *M = Part->getParent();
  339. auto *ArrTy = ArrayType::get(Part->getType(), Args.size());
  340. auto *ConstVal = ConstantArray::get(ArrTy, Args);
  341. auto *GV = new GlobalVariable(*M, ConstVal->getType(), /*isConstant=*/true,
  342. GlobalVariable::PrivateLinkage, ConstVal,
  343. F.getName() + Twine(".resumers"));
  344. // Update coro.begin instruction to refer to this constant.
  345. LLVMContext &C = F.getContext();
  346. auto *BC = ConstantExpr::getPointerCast(GV, Type::getInt8PtrTy(C));
  347. CoroBegin->getId()->setInfo(BC);
  348. }
  349. // Store addresses of Resume/Destroy/Cleanup functions in the coroutine frame.
  350. static void updateCoroFrame(coro::Shape &Shape, Function *ResumeFn,
  351. Function *DestroyFn, Function *CleanupFn) {
  352. IRBuilder<> Builder(Shape.FramePtr->getNextNode());
  353. auto *ResumeAddr = Builder.CreateConstInBoundsGEP2_32(
  354. Shape.FrameTy, Shape.FramePtr, 0, coro::Shape::ResumeField,
  355. "resume.addr");
  356. Builder.CreateStore(ResumeFn, ResumeAddr);
  357. Value *DestroyOrCleanupFn = DestroyFn;
  358. CoroIdInst *CoroId = Shape.CoroBegin->getId();
  359. if (CoroAllocInst *CA = CoroId->getCoroAlloc()) {
  360. // If there is a CoroAlloc and it returns false (meaning we elide the
  361. // allocation, use CleanupFn instead of DestroyFn).
  362. DestroyOrCleanupFn = Builder.CreateSelect(CA, DestroyFn, CleanupFn);
  363. }
  364. auto *DestroyAddr = Builder.CreateConstInBoundsGEP2_32(
  365. Shape.FrameTy, Shape.FramePtr, 0, coro::Shape::DestroyField,
  366. "destroy.addr");
  367. Builder.CreateStore(DestroyOrCleanupFn, DestroyAddr);
  368. }
  369. static void postSplitCleanup(Function &F) {
  370. removeUnreachableBlocks(F);
  371. legacy::FunctionPassManager FPM(F.getParent());
  372. FPM.add(createVerifierPass());
  373. FPM.add(createSCCPPass());
  374. FPM.add(createCFGSimplificationPass());
  375. FPM.add(createEarlyCSEPass());
  376. FPM.add(createCFGSimplificationPass());
  377. FPM.doInitialization();
  378. FPM.run(F);
  379. FPM.doFinalization();
  380. }
  381. // Assuming we arrived at the block NewBlock from Prev instruction, store
  382. // PHI's incoming values in the ResolvedValues map.
  383. static void
  384. scanPHIsAndUpdateValueMap(Instruction *Prev, BasicBlock *NewBlock,
  385. DenseMap<Value *, Value *> &ResolvedValues) {
  386. auto *PrevBB = Prev->getParent();
  387. for (PHINode &PN : NewBlock->phis()) {
  388. auto V = PN.getIncomingValueForBlock(PrevBB);
  389. // See if we already resolved it.
  390. auto VI = ResolvedValues.find(V);
  391. if (VI != ResolvedValues.end())
  392. V = VI->second;
  393. // Remember the value.
  394. ResolvedValues[&PN] = V;
  395. }
  396. }
  397. // Replace a sequence of branches leading to a ret, with a clone of a ret
  398. // instruction. Suspend instruction represented by a switch, track the PHI
  399. // values and select the correct case successor when possible.
  400. static bool simplifyTerminatorLeadingToRet(Instruction *InitialInst) {
  401. DenseMap<Value *, Value *> ResolvedValues;
  402. Instruction *I = InitialInst;
  403. while (I->isTerminator()) {
  404. if (isa<ReturnInst>(I)) {
  405. if (I != InitialInst)
  406. ReplaceInstWithInst(InitialInst, I->clone());
  407. return true;
  408. }
  409. if (auto *BR = dyn_cast<BranchInst>(I)) {
  410. if (BR->isUnconditional()) {
  411. BasicBlock *BB = BR->getSuccessor(0);
  412. scanPHIsAndUpdateValueMap(I, BB, ResolvedValues);
  413. I = BB->getFirstNonPHIOrDbgOrLifetime();
  414. continue;
  415. }
  416. } else if (auto *SI = dyn_cast<SwitchInst>(I)) {
  417. Value *V = SI->getCondition();
  418. auto it = ResolvedValues.find(V);
  419. if (it != ResolvedValues.end())
  420. V = it->second;
  421. if (ConstantInt *Cond = dyn_cast<ConstantInt>(V)) {
  422. BasicBlock *BB = SI->findCaseValue(Cond)->getCaseSuccessor();
  423. scanPHIsAndUpdateValueMap(I, BB, ResolvedValues);
  424. I = BB->getFirstNonPHIOrDbgOrLifetime();
  425. continue;
  426. }
  427. }
  428. return false;
  429. }
  430. return false;
  431. }
  432. // Add musttail to any resume instructions that is immediately followed by a
  433. // suspend (i.e. ret). We do this even in -O0 to support guaranteed tail call
  434. // for symmetrical coroutine control transfer (C++ Coroutines TS extension).
  435. // This transformation is done only in the resume part of the coroutine that has
  436. // identical signature and calling convention as the coro.resume call.
  437. static void addMustTailToCoroResumes(Function &F) {
  438. bool changed = false;
  439. // Collect potential resume instructions.
  440. SmallVector<CallInst *, 4> Resumes;
  441. for (auto &I : instructions(F))
  442. if (auto *Call = dyn_cast<CallInst>(&I))
  443. if (auto *CalledValue = Call->getCalledValue())
  444. // CoroEarly pass replaced coro resumes with indirect calls to an
  445. // address return by CoroSubFnInst intrinsic. See if it is one of those.
  446. if (isa<CoroSubFnInst>(CalledValue->stripPointerCasts()))
  447. Resumes.push_back(Call);
  448. // Set musttail on those that are followed by a ret instruction.
  449. for (CallInst *Call : Resumes)
  450. if (simplifyTerminatorLeadingToRet(Call->getNextNode())) {
  451. Call->setTailCallKind(CallInst::TCK_MustTail);
  452. changed = true;
  453. }
  454. if (changed)
  455. removeUnreachableBlocks(F);
  456. }
  457. // Coroutine has no suspend points. Remove heap allocation for the coroutine
  458. // frame if possible.
  459. static void handleNoSuspendCoroutine(CoroBeginInst *CoroBegin, Type *FrameTy) {
  460. auto *CoroId = CoroBegin->getId();
  461. auto *AllocInst = CoroId->getCoroAlloc();
  462. coro::replaceCoroFree(CoroId, /*Elide=*/AllocInst != nullptr);
  463. if (AllocInst) {
  464. IRBuilder<> Builder(AllocInst);
  465. // FIXME: Need to handle overaligned members.
  466. auto *Frame = Builder.CreateAlloca(FrameTy);
  467. auto *VFrame = Builder.CreateBitCast(Frame, Builder.getInt8PtrTy());
  468. AllocInst->replaceAllUsesWith(Builder.getFalse());
  469. AllocInst->eraseFromParent();
  470. CoroBegin->replaceAllUsesWith(VFrame);
  471. } else {
  472. CoroBegin->replaceAllUsesWith(CoroBegin->getMem());
  473. }
  474. CoroBegin->eraseFromParent();
  475. }
  476. // SimplifySuspendPoint needs to check that there is no calls between
  477. // coro_save and coro_suspend, since any of the calls may potentially resume
  478. // the coroutine and if that is the case we cannot eliminate the suspend point.
  479. static bool hasCallsInBlockBetween(Instruction *From, Instruction *To) {
  480. for (Instruction *I = From; I != To; I = I->getNextNode()) {
  481. // Assume that no intrinsic can resume the coroutine.
  482. if (isa<IntrinsicInst>(I))
  483. continue;
  484. if (CallSite(I))
  485. return true;
  486. }
  487. return false;
  488. }
  489. static bool hasCallsInBlocksBetween(BasicBlock *SaveBB, BasicBlock *ResDesBB) {
  490. SmallPtrSet<BasicBlock *, 8> Set;
  491. SmallVector<BasicBlock *, 8> Worklist;
  492. Set.insert(SaveBB);
  493. Worklist.push_back(ResDesBB);
  494. // Accumulate all blocks between SaveBB and ResDesBB. Because CoroSaveIntr
  495. // returns a token consumed by suspend instruction, all blocks in between
  496. // will have to eventually hit SaveBB when going backwards from ResDesBB.
  497. while (!Worklist.empty()) {
  498. auto *BB = Worklist.pop_back_val();
  499. Set.insert(BB);
  500. for (auto *Pred : predecessors(BB))
  501. if (Set.count(Pred) == 0)
  502. Worklist.push_back(Pred);
  503. }
  504. // SaveBB and ResDesBB are checked separately in hasCallsBetween.
  505. Set.erase(SaveBB);
  506. Set.erase(ResDesBB);
  507. for (auto *BB : Set)
  508. if (hasCallsInBlockBetween(BB->getFirstNonPHI(), nullptr))
  509. return true;
  510. return false;
  511. }
  512. static bool hasCallsBetween(Instruction *Save, Instruction *ResumeOrDestroy) {
  513. auto *SaveBB = Save->getParent();
  514. auto *ResumeOrDestroyBB = ResumeOrDestroy->getParent();
  515. if (SaveBB == ResumeOrDestroyBB)
  516. return hasCallsInBlockBetween(Save->getNextNode(), ResumeOrDestroy);
  517. // Any calls from Save to the end of the block?
  518. if (hasCallsInBlockBetween(Save->getNextNode(), nullptr))
  519. return true;
  520. // Any calls from begging of the block up to ResumeOrDestroy?
  521. if (hasCallsInBlockBetween(ResumeOrDestroyBB->getFirstNonPHI(),
  522. ResumeOrDestroy))
  523. return true;
  524. // Any calls in all of the blocks between SaveBB and ResumeOrDestroyBB?
  525. if (hasCallsInBlocksBetween(SaveBB, ResumeOrDestroyBB))
  526. return true;
  527. return false;
  528. }
  529. // If a SuspendIntrin is preceded by Resume or Destroy, we can eliminate the
  530. // suspend point and replace it with nornal control flow.
  531. static bool simplifySuspendPoint(CoroSuspendInst *Suspend,
  532. CoroBeginInst *CoroBegin) {
  533. Instruction *Prev = Suspend->getPrevNode();
  534. if (!Prev) {
  535. auto *Pred = Suspend->getParent()->getSinglePredecessor();
  536. if (!Pred)
  537. return false;
  538. Prev = Pred->getTerminator();
  539. }
  540. CallSite CS{Prev};
  541. if (!CS)
  542. return false;
  543. auto *CallInstr = CS.getInstruction();
  544. auto *Callee = CS.getCalledValue()->stripPointerCasts();
  545. // See if the callsite is for resumption or destruction of the coroutine.
  546. auto *SubFn = dyn_cast<CoroSubFnInst>(Callee);
  547. if (!SubFn)
  548. return false;
  549. // Does not refer to the current coroutine, we cannot do anything with it.
  550. if (SubFn->getFrame() != CoroBegin)
  551. return false;
  552. // See if the transformation is safe. Specifically, see if there are any
  553. // calls in between Save and CallInstr. They can potenitally resume the
  554. // coroutine rendering this optimization unsafe.
  555. auto *Save = Suspend->getCoroSave();
  556. if (hasCallsBetween(Save, CallInstr))
  557. return false;
  558. // Replace llvm.coro.suspend with the value that results in resumption over
  559. // the resume or cleanup path.
  560. Suspend->replaceAllUsesWith(SubFn->getRawIndex());
  561. Suspend->eraseFromParent();
  562. Save->eraseFromParent();
  563. // No longer need a call to coro.resume or coro.destroy.
  564. if (auto *Invoke = dyn_cast<InvokeInst>(CallInstr)) {
  565. BranchInst::Create(Invoke->getNormalDest(), Invoke);
  566. }
  567. // Grab the CalledValue from CS before erasing the CallInstr.
  568. auto *CalledValue = CS.getCalledValue();
  569. CallInstr->eraseFromParent();
  570. // If no more users remove it. Usually it is a bitcast of SubFn.
  571. if (CalledValue != SubFn && CalledValue->user_empty())
  572. if (auto *I = dyn_cast<Instruction>(CalledValue))
  573. I->eraseFromParent();
  574. // Now we are good to remove SubFn.
  575. if (SubFn->user_empty())
  576. SubFn->eraseFromParent();
  577. return true;
  578. }
  579. // Remove suspend points that are simplified.
  580. static void simplifySuspendPoints(coro::Shape &Shape) {
  581. auto &S = Shape.CoroSuspends;
  582. size_t I = 0, N = S.size();
  583. if (N == 0)
  584. return;
  585. while (true) {
  586. if (simplifySuspendPoint(S[I], Shape.CoroBegin)) {
  587. if (--N == I)
  588. break;
  589. std::swap(S[I], S[N]);
  590. continue;
  591. }
  592. if (++I == N)
  593. break;
  594. }
  595. S.resize(N);
  596. }
  597. static SmallPtrSet<BasicBlock *, 4> getCoroBeginPredBlocks(CoroBeginInst *CB) {
  598. // Collect all blocks that we need to look for instructions to relocate.
  599. SmallPtrSet<BasicBlock *, 4> RelocBlocks;
  600. SmallVector<BasicBlock *, 4> Work;
  601. Work.push_back(CB->getParent());
  602. do {
  603. BasicBlock *Current = Work.pop_back_val();
  604. for (BasicBlock *BB : predecessors(Current))
  605. if (RelocBlocks.count(BB) == 0) {
  606. RelocBlocks.insert(BB);
  607. Work.push_back(BB);
  608. }
  609. } while (!Work.empty());
  610. return RelocBlocks;
  611. }
  612. static SmallPtrSet<Instruction *, 8>
  613. getNotRelocatableInstructions(CoroBeginInst *CoroBegin,
  614. SmallPtrSetImpl<BasicBlock *> &RelocBlocks) {
  615. SmallPtrSet<Instruction *, 8> DoNotRelocate;
  616. // Collect all instructions that we should not relocate
  617. SmallVector<Instruction *, 8> Work;
  618. // Start with CoroBegin and terminators of all preceding blocks.
  619. Work.push_back(CoroBegin);
  620. BasicBlock *CoroBeginBB = CoroBegin->getParent();
  621. for (BasicBlock *BB : RelocBlocks)
  622. if (BB != CoroBeginBB)
  623. Work.push_back(BB->getTerminator());
  624. // For every instruction in the Work list, place its operands in DoNotRelocate
  625. // set.
  626. do {
  627. Instruction *Current = Work.pop_back_val();
  628. LLVM_DEBUG(dbgs() << "CoroSplit: Will not relocate: " << *Current << "\n");
  629. DoNotRelocate.insert(Current);
  630. for (Value *U : Current->operands()) {
  631. auto *I = dyn_cast<Instruction>(U);
  632. if (!I)
  633. continue;
  634. if (auto *A = dyn_cast<AllocaInst>(I)) {
  635. // Stores to alloca instructions that occur before the coroutine frame
  636. // is allocated should not be moved; the stored values may be used by
  637. // the coroutine frame allocator. The operands to those stores must also
  638. // remain in place.
  639. for (const auto &User : A->users())
  640. if (auto *SI = dyn_cast<llvm::StoreInst>(User))
  641. if (RelocBlocks.count(SI->getParent()) != 0 &&
  642. DoNotRelocate.count(SI) == 0) {
  643. Work.push_back(SI);
  644. DoNotRelocate.insert(SI);
  645. }
  646. continue;
  647. }
  648. if (DoNotRelocate.count(I) == 0) {
  649. Work.push_back(I);
  650. DoNotRelocate.insert(I);
  651. }
  652. }
  653. } while (!Work.empty());
  654. return DoNotRelocate;
  655. }
  656. static void relocateInstructionBefore(CoroBeginInst *CoroBegin, Function &F) {
  657. // Analyze which non-alloca instructions are needed for allocation and
  658. // relocate the rest to after coro.begin. We need to do it, since some of the
  659. // targets of those instructions may be placed into coroutine frame memory
  660. // for which becomes available after coro.begin intrinsic.
  661. auto BlockSet = getCoroBeginPredBlocks(CoroBegin);
  662. auto DoNotRelocateSet = getNotRelocatableInstructions(CoroBegin, BlockSet);
  663. Instruction *InsertPt = CoroBegin->getNextNode();
  664. BasicBlock &BB = F.getEntryBlock(); // TODO: Look at other blocks as well.
  665. for (auto B = BB.begin(), E = BB.end(); B != E;) {
  666. Instruction &I = *B++;
  667. if (isa<AllocaInst>(&I))
  668. continue;
  669. if (&I == CoroBegin)
  670. break;
  671. if (DoNotRelocateSet.count(&I))
  672. continue;
  673. I.moveBefore(InsertPt);
  674. }
  675. }
  676. static void splitCoroutine(Function &F, CallGraph &CG, CallGraphSCC &SCC) {
  677. EliminateUnreachableBlocks(F);
  678. coro::Shape Shape(F);
  679. if (!Shape.CoroBegin)
  680. return;
  681. simplifySuspendPoints(Shape);
  682. relocateInstructionBefore(Shape.CoroBegin, F);
  683. buildCoroutineFrame(F, Shape);
  684. replaceFrameSize(Shape);
  685. // If there are no suspend points, no split required, just remove
  686. // the allocation and deallocation blocks, they are not needed.
  687. if (Shape.CoroSuspends.empty()) {
  688. handleNoSuspendCoroutine(Shape.CoroBegin, Shape.FrameTy);
  689. removeCoroEnds(Shape);
  690. postSplitCleanup(F);
  691. coro::updateCallGraph(F, {}, CG, SCC);
  692. return;
  693. }
  694. auto *ResumeEntry = createResumeEntryBlock(F, Shape);
  695. auto ResumeClone = createClone(F, ".resume", Shape, ResumeEntry, 0);
  696. auto DestroyClone = createClone(F, ".destroy", Shape, ResumeEntry, 1);
  697. auto CleanupClone = createClone(F, ".cleanup", Shape, ResumeEntry, 2);
  698. // We no longer need coro.end in F.
  699. removeCoroEnds(Shape);
  700. postSplitCleanup(F);
  701. postSplitCleanup(*ResumeClone);
  702. postSplitCleanup(*DestroyClone);
  703. postSplitCleanup(*CleanupClone);
  704. addMustTailToCoroResumes(*ResumeClone);
  705. // Store addresses resume/destroy/cleanup functions in the coroutine frame.
  706. updateCoroFrame(Shape, ResumeClone, DestroyClone, CleanupClone);
  707. // Create a constant array referring to resume/destroy/clone functions pointed
  708. // by the last argument of @llvm.coro.info, so that CoroElide pass can
  709. // determined correct function to call.
  710. setCoroInfo(F, Shape.CoroBegin, {ResumeClone, DestroyClone, CleanupClone});
  711. // Update call graph and add the functions we created to the SCC.
  712. coro::updateCallGraph(F, {ResumeClone, DestroyClone, CleanupClone}, CG, SCC);
  713. }
  714. // When we see the coroutine the first time, we insert an indirect call to a
  715. // devirt trigger function and mark the coroutine that it is now ready for
  716. // split.
  717. static void prepareForSplit(Function &F, CallGraph &CG) {
  718. Module &M = *F.getParent();
  719. LLVMContext &Context = F.getContext();
  720. #ifndef NDEBUG
  721. Function *DevirtFn = M.getFunction(CORO_DEVIRT_TRIGGER_FN);
  722. assert(DevirtFn && "coro.devirt.trigger function not found");
  723. #endif
  724. F.addFnAttr(CORO_PRESPLIT_ATTR, PREPARED_FOR_SPLIT);
  725. // Insert an indirect call sequence that will be devirtualized by CoroElide
  726. // pass:
  727. // %0 = call i8* @llvm.coro.subfn.addr(i8* null, i8 -1)
  728. // %1 = bitcast i8* %0 to void(i8*)*
  729. // call void %1(i8* null)
  730. coro::LowererBase Lowerer(M);
  731. Instruction *InsertPt = F.getEntryBlock().getTerminator();
  732. auto *Null = ConstantPointerNull::get(Type::getInt8PtrTy(Context));
  733. auto *DevirtFnAddr =
  734. Lowerer.makeSubFnCall(Null, CoroSubFnInst::RestartTrigger, InsertPt);
  735. FunctionType *FnTy = FunctionType::get(Type::getVoidTy(Context),
  736. {Type::getInt8PtrTy(Context)}, false);
  737. auto *IndirectCall = CallInst::Create(FnTy, DevirtFnAddr, Null, "", InsertPt);
  738. // Update CG graph with an indirect call we just added.
  739. CG[&F]->addCalledFunction(IndirectCall, CG.getCallsExternalNode());
  740. }
  741. // Make sure that there is a devirtualization trigger function that CoroSplit
  742. // pass uses the force restart CGSCC pipeline. If devirt trigger function is not
  743. // found, we will create one and add it to the current SCC.
  744. static void createDevirtTriggerFunc(CallGraph &CG, CallGraphSCC &SCC) {
  745. Module &M = CG.getModule();
  746. if (M.getFunction(CORO_DEVIRT_TRIGGER_FN))
  747. return;
  748. LLVMContext &C = M.getContext();
  749. auto *FnTy = FunctionType::get(Type::getVoidTy(C), Type::getInt8PtrTy(C),
  750. /*IsVarArgs=*/false);
  751. Function *DevirtFn =
  752. Function::Create(FnTy, GlobalValue::LinkageTypes::PrivateLinkage,
  753. CORO_DEVIRT_TRIGGER_FN, &M);
  754. DevirtFn->addFnAttr(Attribute::AlwaysInline);
  755. auto *Entry = BasicBlock::Create(C, "entry", DevirtFn);
  756. ReturnInst::Create(C, Entry);
  757. auto *Node = CG.getOrInsertFunction(DevirtFn);
  758. SmallVector<CallGraphNode *, 8> Nodes(SCC.begin(), SCC.end());
  759. Nodes.push_back(Node);
  760. SCC.initialize(Nodes);
  761. }
  762. //===----------------------------------------------------------------------===//
  763. // Top Level Driver
  764. //===----------------------------------------------------------------------===//
  765. namespace {
  766. struct CoroSplit : public CallGraphSCCPass {
  767. static char ID; // Pass identification, replacement for typeid
  768. CoroSplit() : CallGraphSCCPass(ID) {
  769. initializeCoroSplitPass(*PassRegistry::getPassRegistry());
  770. }
  771. bool Run = false;
  772. // A coroutine is identified by the presence of coro.begin intrinsic, if
  773. // we don't have any, this pass has nothing to do.
  774. bool doInitialization(CallGraph &CG) override {
  775. Run = coro::declaresIntrinsics(CG.getModule(), {"llvm.coro.begin"});
  776. return CallGraphSCCPass::doInitialization(CG);
  777. }
  778. bool runOnSCC(CallGraphSCC &SCC) override {
  779. if (!Run)
  780. return false;
  781. // Find coroutines for processing.
  782. SmallVector<Function *, 4> Coroutines;
  783. for (CallGraphNode *CGN : SCC)
  784. if (auto *F = CGN->getFunction())
  785. if (F->hasFnAttribute(CORO_PRESPLIT_ATTR))
  786. Coroutines.push_back(F);
  787. if (Coroutines.empty())
  788. return false;
  789. CallGraph &CG = getAnalysis<CallGraphWrapperPass>().getCallGraph();
  790. createDevirtTriggerFunc(CG, SCC);
  791. for (Function *F : Coroutines) {
  792. Attribute Attr = F->getFnAttribute(CORO_PRESPLIT_ATTR);
  793. StringRef Value = Attr.getValueAsString();
  794. LLVM_DEBUG(dbgs() << "CoroSplit: Processing coroutine '" << F->getName()
  795. << "' state: " << Value << "\n");
  796. if (Value == UNPREPARED_FOR_SPLIT) {
  797. prepareForSplit(*F, CG);
  798. continue;
  799. }
  800. F->removeFnAttr(CORO_PRESPLIT_ATTR);
  801. splitCoroutine(*F, CG, SCC);
  802. }
  803. return true;
  804. }
  805. void getAnalysisUsage(AnalysisUsage &AU) const override {
  806. CallGraphSCCPass::getAnalysisUsage(AU);
  807. }
  808. StringRef getPassName() const override { return "Coroutine Splitting"; }
  809. };
  810. } // end anonymous namespace
  811. char CoroSplit::ID = 0;
  812. INITIALIZE_PASS_BEGIN(
  813. CoroSplit, "coro-split",
  814. "Split coroutine into a set of functions driving its state machine", false,
  815. false)
  816. INITIALIZE_PASS_DEPENDENCY(CallGraphWrapperPass)
  817. INITIALIZE_PASS_END(
  818. CoroSplit, "coro-split",
  819. "Split coroutine into a set of functions driving its state machine", false,
  820. false)
  821. Pass *llvm::createCoroSplitPass() { return new CoroSplit(); }