12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959 |
- //===- CoroSplit.cpp - Converts a coroutine into a state machine ----------===//
- //
- // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
- // See https://llvm.org/LICENSE.txt for license information.
- // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
- //
- //===----------------------------------------------------------------------===//
- // This pass builds the coroutine frame and outlines resume and destroy parts
- // of the coroutine into separate functions.
- //
- // We present a coroutine to LLVM as an ordinary function with suspension
- // points marked up with intrinsics. We let the optimizer party on the coroutine
- // as a single function for as long as possible. Shortly before the coroutine is
- // eligible to be inlined into its callers, we split up the coroutine into parts
- // corresponding to the initial, resume and destroy invocations of the
- // coroutine, add them to the current SCC and restart the IPO pipeline to
- // optimize the coroutine subfunctions we extracted before proceeding to the
- // caller of the coroutine.
- //===----------------------------------------------------------------------===//
- #include "CoroInstr.h"
- #include "CoroInternal.h"
- #include "llvm/ADT/DenseMap.h"
- #include "llvm/ADT/SmallPtrSet.h"
- #include "llvm/ADT/SmallVector.h"
- #include "llvm/ADT/StringRef.h"
- #include "llvm/ADT/Twine.h"
- #include "llvm/Analysis/CallGraph.h"
- #include "llvm/Analysis/CallGraphSCCPass.h"
- #include "llvm/Transforms/Utils/Local.h"
- #include "llvm/IR/Argument.h"
- #include "llvm/IR/Attributes.h"
- #include "llvm/IR/BasicBlock.h"
- #include "llvm/IR/CFG.h"
- #include "llvm/IR/CallSite.h"
- #include "llvm/IR/CallingConv.h"
- #include "llvm/IR/Constants.h"
- #include "llvm/IR/DataLayout.h"
- #include "llvm/IR/DerivedTypes.h"
- #include "llvm/IR/Function.h"
- #include "llvm/IR/GlobalValue.h"
- #include "llvm/IR/GlobalVariable.h"
- #include "llvm/IR/IRBuilder.h"
- #include "llvm/IR/InstIterator.h"
- #include "llvm/IR/InstrTypes.h"
- #include "llvm/IR/Instruction.h"
- #include "llvm/IR/Instructions.h"
- #include "llvm/IR/IntrinsicInst.h"
- #include "llvm/IR/LLVMContext.h"
- #include "llvm/IR/LegacyPassManager.h"
- #include "llvm/IR/Module.h"
- #include "llvm/IR/Type.h"
- #include "llvm/IR/Value.h"
- #include "llvm/IR/Verifier.h"
- #include "llvm/Pass.h"
- #include "llvm/Support/Casting.h"
- #include "llvm/Support/Debug.h"
- #include "llvm/Support/raw_ostream.h"
- #include "llvm/Transforms/Scalar.h"
- #include "llvm/Transforms/Utils/BasicBlockUtils.h"
- #include "llvm/Transforms/Utils/Cloning.h"
- #include "llvm/Transforms/Utils/ValueMapper.h"
- #include <cassert>
- #include <cstddef>
- #include <cstdint>
- #include <initializer_list>
- #include <iterator>
- using namespace llvm;
- #define DEBUG_TYPE "coro-split"
- // Create an entry block for a resume function with a switch that will jump to
- // suspend points.
- static BasicBlock *createResumeEntryBlock(Function &F, coro::Shape &Shape) {
- LLVMContext &C = F.getContext();
- // resume.entry:
- // %index.addr = getelementptr inbounds %f.Frame, %f.Frame* %FramePtr, i32 0,
- // i32 2
- // % index = load i32, i32* %index.addr
- // switch i32 %index, label %unreachable [
- // i32 0, label %resume.0
- // i32 1, label %resume.1
- // ...
- // ]
- auto *NewEntry = BasicBlock::Create(C, "resume.entry", &F);
- auto *UnreachBB = BasicBlock::Create(C, "unreachable", &F);
- IRBuilder<> Builder(NewEntry);
- auto *FramePtr = Shape.FramePtr;
- auto *FrameTy = Shape.FrameTy;
- auto *GepIndex = Builder.CreateConstInBoundsGEP2_32(
- FrameTy, FramePtr, 0, coro::Shape::IndexField, "index.addr");
- auto *Index = Builder.CreateLoad(Shape.getIndexType(), GepIndex, "index");
- auto *Switch =
- Builder.CreateSwitch(Index, UnreachBB, Shape.CoroSuspends.size());
- Shape.ResumeSwitch = Switch;
- size_t SuspendIndex = 0;
- for (CoroSuspendInst *S : Shape.CoroSuspends) {
- ConstantInt *IndexVal = Shape.getIndex(SuspendIndex);
- // Replace CoroSave with a store to Index:
- // %index.addr = getelementptr %f.frame... (index field number)
- // store i32 0, i32* %index.addr1
- auto *Save = S->getCoroSave();
- Builder.SetInsertPoint(Save);
- if (S->isFinal()) {
- // Final suspend point is represented by storing zero in ResumeFnAddr.
- auto *GepIndex = Builder.CreateConstInBoundsGEP2_32(FrameTy, FramePtr, 0,
- 0, "ResumeFn.addr");
- auto *NullPtr = ConstantPointerNull::get(cast<PointerType>(
- cast<PointerType>(GepIndex->getType())->getElementType()));
- Builder.CreateStore(NullPtr, GepIndex);
- } else {
- auto *GepIndex = Builder.CreateConstInBoundsGEP2_32(
- FrameTy, FramePtr, 0, coro::Shape::IndexField, "index.addr");
- Builder.CreateStore(IndexVal, GepIndex);
- }
- Save->replaceAllUsesWith(ConstantTokenNone::get(C));
- Save->eraseFromParent();
- // Split block before and after coro.suspend and add a jump from an entry
- // switch:
- //
- // whateverBB:
- // whatever
- // %0 = call i8 @llvm.coro.suspend(token none, i1 false)
- // switch i8 %0, label %suspend[i8 0, label %resume
- // i8 1, label %cleanup]
- // becomes:
- //
- // whateverBB:
- // whatever
- // br label %resume.0.landing
- //
- // resume.0: ; <--- jump from the switch in the resume.entry
- // %0 = tail call i8 @llvm.coro.suspend(token none, i1 false)
- // br label %resume.0.landing
- //
- // resume.0.landing:
- // %1 = phi i8[-1, %whateverBB], [%0, %resume.0]
- // switch i8 % 1, label %suspend [i8 0, label %resume
- // i8 1, label %cleanup]
- auto *SuspendBB = S->getParent();
- auto *ResumeBB =
- SuspendBB->splitBasicBlock(S, "resume." + Twine(SuspendIndex));
- auto *LandingBB = ResumeBB->splitBasicBlock(
- S->getNextNode(), ResumeBB->getName() + Twine(".landing"));
- Switch->addCase(IndexVal, ResumeBB);
- cast<BranchInst>(SuspendBB->getTerminator())->setSuccessor(0, LandingBB);
- auto *PN = PHINode::Create(Builder.getInt8Ty(), 2, "", &LandingBB->front());
- S->replaceAllUsesWith(PN);
- PN->addIncoming(Builder.getInt8(-1), SuspendBB);
- PN->addIncoming(S, ResumeBB);
- ++SuspendIndex;
- }
- Builder.SetInsertPoint(UnreachBB);
- Builder.CreateUnreachable();
- return NewEntry;
- }
- // In Resumers, we replace fallthrough coro.end with ret void and delete the
- // rest of the block.
- static void replaceFallthroughCoroEnd(IntrinsicInst *End,
- ValueToValueMapTy &VMap) {
- auto *NewE = cast<IntrinsicInst>(VMap[End]);
- ReturnInst::Create(NewE->getContext(), nullptr, NewE);
- // Remove the rest of the block, by splitting it into an unreachable block.
- auto *BB = NewE->getParent();
- BB->splitBasicBlock(NewE);
- BB->getTerminator()->eraseFromParent();
- }
- // In Resumers, we replace unwind coro.end with True to force the immediate
- // unwind to caller.
- static void replaceUnwindCoroEnds(coro::Shape &Shape, ValueToValueMapTy &VMap) {
- if (Shape.CoroEnds.empty())
- return;
- LLVMContext &Context = Shape.CoroEnds.front()->getContext();
- auto *True = ConstantInt::getTrue(Context);
- for (CoroEndInst *CE : Shape.CoroEnds) {
- if (!CE->isUnwind())
- continue;
- auto *NewCE = cast<IntrinsicInst>(VMap[CE]);
- // If coro.end has an associated bundle, add cleanupret instruction.
- if (auto Bundle = NewCE->getOperandBundle(LLVMContext::OB_funclet)) {
- Value *FromPad = Bundle->Inputs[0];
- auto *CleanupRet = CleanupReturnInst::Create(FromPad, nullptr, NewCE);
- NewCE->getParent()->splitBasicBlock(NewCE);
- CleanupRet->getParent()->getTerminator()->eraseFromParent();
- }
- NewCE->replaceAllUsesWith(True);
- NewCE->eraseFromParent();
- }
- }
- // Rewrite final suspend point handling. We do not use suspend index to
- // represent the final suspend point. Instead we zero-out ResumeFnAddr in the
- // coroutine frame, since it is undefined behavior to resume a coroutine
- // suspended at the final suspend point. Thus, in the resume function, we can
- // simply remove the last case (when coro::Shape is built, the final suspend
- // point (if present) is always the last element of CoroSuspends array).
- // In the destroy function, we add a code sequence to check if ResumeFnAddress
- // is Null, and if so, jump to the appropriate label to handle cleanup from the
- // final suspend point.
- static void handleFinalSuspend(IRBuilder<> &Builder, Value *FramePtr,
- coro::Shape &Shape, SwitchInst *Switch,
- bool IsDestroy) {
- assert(Shape.HasFinalSuspend);
- auto FinalCaseIt = std::prev(Switch->case_end());
- BasicBlock *ResumeBB = FinalCaseIt->getCaseSuccessor();
- Switch->removeCase(FinalCaseIt);
- if (IsDestroy) {
- BasicBlock *OldSwitchBB = Switch->getParent();
- auto *NewSwitchBB = OldSwitchBB->splitBasicBlock(Switch, "Switch");
- Builder.SetInsertPoint(OldSwitchBB->getTerminator());
- auto *GepIndex = Builder.CreateConstInBoundsGEP2_32(Shape.FrameTy, FramePtr,
- 0, 0, "ResumeFn.addr");
- auto *Load = Builder.CreateLoad(
- Shape.FrameTy->getElementType(coro::Shape::ResumeField), GepIndex);
- auto *NullPtr =
- ConstantPointerNull::get(cast<PointerType>(Load->getType()));
- auto *Cond = Builder.CreateICmpEQ(Load, NullPtr);
- Builder.CreateCondBr(Cond, ResumeBB, NewSwitchBB);
- OldSwitchBB->getTerminator()->eraseFromParent();
- }
- }
- // Create a resume clone by cloning the body of the original function, setting
- // new entry block and replacing coro.suspend an appropriate value to force
- // resume or cleanup pass for every suspend point.
- static Function *createClone(Function &F, Twine Suffix, coro::Shape &Shape,
- BasicBlock *ResumeEntry, int8_t FnIndex) {
- Module *M = F.getParent();
- auto *FrameTy = Shape.FrameTy;
- auto *FnPtrTy = cast<PointerType>(FrameTy->getElementType(0));
- auto *FnTy = cast<FunctionType>(FnPtrTy->getElementType());
- Function *NewF =
- Function::Create(FnTy, GlobalValue::LinkageTypes::ExternalLinkage,
- F.getName() + Suffix, M);
- NewF->addParamAttr(0, Attribute::NonNull);
- NewF->addParamAttr(0, Attribute::NoAlias);
- ValueToValueMapTy VMap;
- // Replace all args with undefs. The buildCoroutineFrame algorithm already
- // rewritten access to the args that occurs after suspend points with loads
- // and stores to/from the coroutine frame.
- for (Argument &A : F.args())
- VMap[&A] = UndefValue::get(A.getType());
- SmallVector<ReturnInst *, 4> Returns;
- CloneFunctionInto(NewF, &F, VMap, /*ModuleLevelChanges=*/true, Returns);
- NewF->setLinkage(GlobalValue::LinkageTypes::InternalLinkage);
- // Remove old returns.
- for (ReturnInst *Return : Returns)
- changeToUnreachable(Return, /*UseLLVMTrap=*/false);
- // Remove old return attributes.
- NewF->removeAttributes(
- AttributeList::ReturnIndex,
- AttributeFuncs::typeIncompatible(NewF->getReturnType()));
- // Make AllocaSpillBlock the new entry block.
- auto *SwitchBB = cast<BasicBlock>(VMap[ResumeEntry]);
- auto *Entry = cast<BasicBlock>(VMap[Shape.AllocaSpillBlock]);
- Entry->moveBefore(&NewF->getEntryBlock());
- Entry->getTerminator()->eraseFromParent();
- BranchInst::Create(SwitchBB, Entry);
- Entry->setName("entry" + Suffix);
- // Clear all predecessors of the new entry block.
- auto *Switch = cast<SwitchInst>(VMap[Shape.ResumeSwitch]);
- Entry->replaceAllUsesWith(Switch->getDefaultDest());
- IRBuilder<> Builder(&NewF->getEntryBlock().front());
- // Remap frame pointer.
- Argument *NewFramePtr = &*NewF->arg_begin();
- Value *OldFramePtr = cast<Value>(VMap[Shape.FramePtr]);
- NewFramePtr->takeName(OldFramePtr);
- OldFramePtr->replaceAllUsesWith(NewFramePtr);
- // Remap vFrame pointer.
- auto *NewVFrame = Builder.CreateBitCast(
- NewFramePtr, Type::getInt8PtrTy(Builder.getContext()), "vFrame");
- Value *OldVFrame = cast<Value>(VMap[Shape.CoroBegin]);
- OldVFrame->replaceAllUsesWith(NewVFrame);
- // Rewrite final suspend handling as it is not done via switch (allows to
- // remove final case from the switch, since it is undefined behavior to resume
- // the coroutine suspended at the final suspend point.
- if (Shape.HasFinalSuspend) {
- auto *Switch = cast<SwitchInst>(VMap[Shape.ResumeSwitch]);
- bool IsDestroy = FnIndex != 0;
- handleFinalSuspend(Builder, NewFramePtr, Shape, Switch, IsDestroy);
- }
- // Replace coro suspend with the appropriate resume index.
- // Replacing coro.suspend with (0) will result in control flow proceeding to
- // a resume label associated with a suspend point, replacing it with (1) will
- // result in control flow proceeding to a cleanup label associated with this
- // suspend point.
- auto *NewValue = Builder.getInt8(FnIndex ? 1 : 0);
- for (CoroSuspendInst *CS : Shape.CoroSuspends) {
- auto *MappedCS = cast<CoroSuspendInst>(VMap[CS]);
- MappedCS->replaceAllUsesWith(NewValue);
- MappedCS->eraseFromParent();
- }
- // Remove coro.end intrinsics.
- replaceFallthroughCoroEnd(Shape.CoroEnds.front(), VMap);
- replaceUnwindCoroEnds(Shape, VMap);
- // Eliminate coro.free from the clones, replacing it with 'null' in cleanup,
- // to suppress deallocation code.
- coro::replaceCoroFree(cast<CoroIdInst>(VMap[Shape.CoroBegin->getId()]),
- /*Elide=*/FnIndex == 2);
- NewF->setCallingConv(CallingConv::Fast);
- return NewF;
- }
- static void removeCoroEnds(coro::Shape &Shape) {
- if (Shape.CoroEnds.empty())
- return;
- LLVMContext &Context = Shape.CoroEnds.front()->getContext();
- auto *False = ConstantInt::getFalse(Context);
- for (CoroEndInst *CE : Shape.CoroEnds) {
- CE->replaceAllUsesWith(False);
- CE->eraseFromParent();
- }
- }
- static void replaceFrameSize(coro::Shape &Shape) {
- if (Shape.CoroSizes.empty())
- return;
- // In the same function all coro.sizes should have the same result type.
- auto *SizeIntrin = Shape.CoroSizes.back();
- Module *M = SizeIntrin->getModule();
- const DataLayout &DL = M->getDataLayout();
- auto Size = DL.getTypeAllocSize(Shape.FrameTy);
- auto *SizeConstant = ConstantInt::get(SizeIntrin->getType(), Size);
- for (CoroSizeInst *CS : Shape.CoroSizes) {
- CS->replaceAllUsesWith(SizeConstant);
- CS->eraseFromParent();
- }
- }
- // Create a global constant array containing pointers to functions provided and
- // set Info parameter of CoroBegin to point at this constant. Example:
- //
- // @f.resumers = internal constant [2 x void(%f.frame*)*]
- // [void(%f.frame*)* @f.resume, void(%f.frame*)* @f.destroy]
- // define void @f() {
- // ...
- // call i8* @llvm.coro.begin(i8* null, i32 0, i8* null,
- // i8* bitcast([2 x void(%f.frame*)*] * @f.resumers to i8*))
- //
- // Assumes that all the functions have the same signature.
- static void setCoroInfo(Function &F, CoroBeginInst *CoroBegin,
- std::initializer_list<Function *> Fns) {
- SmallVector<Constant *, 4> Args(Fns.begin(), Fns.end());
- assert(!Args.empty());
- Function *Part = *Fns.begin();
- Module *M = Part->getParent();
- auto *ArrTy = ArrayType::get(Part->getType(), Args.size());
- auto *ConstVal = ConstantArray::get(ArrTy, Args);
- auto *GV = new GlobalVariable(*M, ConstVal->getType(), /*isConstant=*/true,
- GlobalVariable::PrivateLinkage, ConstVal,
- F.getName() + Twine(".resumers"));
- // Update coro.begin instruction to refer to this constant.
- LLVMContext &C = F.getContext();
- auto *BC = ConstantExpr::getPointerCast(GV, Type::getInt8PtrTy(C));
- CoroBegin->getId()->setInfo(BC);
- }
- // Store addresses of Resume/Destroy/Cleanup functions in the coroutine frame.
- static void updateCoroFrame(coro::Shape &Shape, Function *ResumeFn,
- Function *DestroyFn, Function *CleanupFn) {
- IRBuilder<> Builder(Shape.FramePtr->getNextNode());
- auto *ResumeAddr = Builder.CreateConstInBoundsGEP2_32(
- Shape.FrameTy, Shape.FramePtr, 0, coro::Shape::ResumeField,
- "resume.addr");
- Builder.CreateStore(ResumeFn, ResumeAddr);
- Value *DestroyOrCleanupFn = DestroyFn;
- CoroIdInst *CoroId = Shape.CoroBegin->getId();
- if (CoroAllocInst *CA = CoroId->getCoroAlloc()) {
- // If there is a CoroAlloc and it returns false (meaning we elide the
- // allocation, use CleanupFn instead of DestroyFn).
- DestroyOrCleanupFn = Builder.CreateSelect(CA, DestroyFn, CleanupFn);
- }
- auto *DestroyAddr = Builder.CreateConstInBoundsGEP2_32(
- Shape.FrameTy, Shape.FramePtr, 0, coro::Shape::DestroyField,
- "destroy.addr");
- Builder.CreateStore(DestroyOrCleanupFn, DestroyAddr);
- }
- static void postSplitCleanup(Function &F) {
- removeUnreachableBlocks(F);
- legacy::FunctionPassManager FPM(F.getParent());
- FPM.add(createVerifierPass());
- FPM.add(createSCCPPass());
- FPM.add(createCFGSimplificationPass());
- FPM.add(createEarlyCSEPass());
- FPM.add(createCFGSimplificationPass());
- FPM.doInitialization();
- FPM.run(F);
- FPM.doFinalization();
- }
- // Assuming we arrived at the block NewBlock from Prev instruction, store
- // PHI's incoming values in the ResolvedValues map.
- static void
- scanPHIsAndUpdateValueMap(Instruction *Prev, BasicBlock *NewBlock,
- DenseMap<Value *, Value *> &ResolvedValues) {
- auto *PrevBB = Prev->getParent();
- for (PHINode &PN : NewBlock->phis()) {
- auto V = PN.getIncomingValueForBlock(PrevBB);
- // See if we already resolved it.
- auto VI = ResolvedValues.find(V);
- if (VI != ResolvedValues.end())
- V = VI->second;
- // Remember the value.
- ResolvedValues[&PN] = V;
- }
- }
- // Replace a sequence of branches leading to a ret, with a clone of a ret
- // instruction. Suspend instruction represented by a switch, track the PHI
- // values and select the correct case successor when possible.
- static bool simplifyTerminatorLeadingToRet(Instruction *InitialInst) {
- DenseMap<Value *, Value *> ResolvedValues;
- Instruction *I = InitialInst;
- while (I->isTerminator()) {
- if (isa<ReturnInst>(I)) {
- if (I != InitialInst)
- ReplaceInstWithInst(InitialInst, I->clone());
- return true;
- }
- if (auto *BR = dyn_cast<BranchInst>(I)) {
- if (BR->isUnconditional()) {
- BasicBlock *BB = BR->getSuccessor(0);
- scanPHIsAndUpdateValueMap(I, BB, ResolvedValues);
- I = BB->getFirstNonPHIOrDbgOrLifetime();
- continue;
- }
- } else if (auto *SI = dyn_cast<SwitchInst>(I)) {
- Value *V = SI->getCondition();
- auto it = ResolvedValues.find(V);
- if (it != ResolvedValues.end())
- V = it->second;
- if (ConstantInt *Cond = dyn_cast<ConstantInt>(V)) {
- BasicBlock *BB = SI->findCaseValue(Cond)->getCaseSuccessor();
- scanPHIsAndUpdateValueMap(I, BB, ResolvedValues);
- I = BB->getFirstNonPHIOrDbgOrLifetime();
- continue;
- }
- }
- return false;
- }
- return false;
- }
- // Add musttail to any resume instructions that is immediately followed by a
- // suspend (i.e. ret). We do this even in -O0 to support guaranteed tail call
- // for symmetrical coroutine control transfer (C++ Coroutines TS extension).
- // This transformation is done only in the resume part of the coroutine that has
- // identical signature and calling convention as the coro.resume call.
- static void addMustTailToCoroResumes(Function &F) {
- bool changed = false;
- // Collect potential resume instructions.
- SmallVector<CallInst *, 4> Resumes;
- for (auto &I : instructions(F))
- if (auto *Call = dyn_cast<CallInst>(&I))
- if (auto *CalledValue = Call->getCalledValue())
- // CoroEarly pass replaced coro resumes with indirect calls to an
- // address return by CoroSubFnInst intrinsic. See if it is one of those.
- if (isa<CoroSubFnInst>(CalledValue->stripPointerCasts()))
- Resumes.push_back(Call);
- // Set musttail on those that are followed by a ret instruction.
- for (CallInst *Call : Resumes)
- if (simplifyTerminatorLeadingToRet(Call->getNextNode())) {
- Call->setTailCallKind(CallInst::TCK_MustTail);
- changed = true;
- }
- if (changed)
- removeUnreachableBlocks(F);
- }
- // Coroutine has no suspend points. Remove heap allocation for the coroutine
- // frame if possible.
- static void handleNoSuspendCoroutine(CoroBeginInst *CoroBegin, Type *FrameTy) {
- auto *CoroId = CoroBegin->getId();
- auto *AllocInst = CoroId->getCoroAlloc();
- coro::replaceCoroFree(CoroId, /*Elide=*/AllocInst != nullptr);
- if (AllocInst) {
- IRBuilder<> Builder(AllocInst);
- // FIXME: Need to handle overaligned members.
- auto *Frame = Builder.CreateAlloca(FrameTy);
- auto *VFrame = Builder.CreateBitCast(Frame, Builder.getInt8PtrTy());
- AllocInst->replaceAllUsesWith(Builder.getFalse());
- AllocInst->eraseFromParent();
- CoroBegin->replaceAllUsesWith(VFrame);
- } else {
- CoroBegin->replaceAllUsesWith(CoroBegin->getMem());
- }
- CoroBegin->eraseFromParent();
- }
- // SimplifySuspendPoint needs to check that there is no calls between
- // coro_save and coro_suspend, since any of the calls may potentially resume
- // the coroutine and if that is the case we cannot eliminate the suspend point.
- static bool hasCallsInBlockBetween(Instruction *From, Instruction *To) {
- for (Instruction *I = From; I != To; I = I->getNextNode()) {
- // Assume that no intrinsic can resume the coroutine.
- if (isa<IntrinsicInst>(I))
- continue;
- if (CallSite(I))
- return true;
- }
- return false;
- }
- static bool hasCallsInBlocksBetween(BasicBlock *SaveBB, BasicBlock *ResDesBB) {
- SmallPtrSet<BasicBlock *, 8> Set;
- SmallVector<BasicBlock *, 8> Worklist;
- Set.insert(SaveBB);
- Worklist.push_back(ResDesBB);
- // Accumulate all blocks between SaveBB and ResDesBB. Because CoroSaveIntr
- // returns a token consumed by suspend instruction, all blocks in between
- // will have to eventually hit SaveBB when going backwards from ResDesBB.
- while (!Worklist.empty()) {
- auto *BB = Worklist.pop_back_val();
- Set.insert(BB);
- for (auto *Pred : predecessors(BB))
- if (Set.count(Pred) == 0)
- Worklist.push_back(Pred);
- }
- // SaveBB and ResDesBB are checked separately in hasCallsBetween.
- Set.erase(SaveBB);
- Set.erase(ResDesBB);
- for (auto *BB : Set)
- if (hasCallsInBlockBetween(BB->getFirstNonPHI(), nullptr))
- return true;
- return false;
- }
- static bool hasCallsBetween(Instruction *Save, Instruction *ResumeOrDestroy) {
- auto *SaveBB = Save->getParent();
- auto *ResumeOrDestroyBB = ResumeOrDestroy->getParent();
- if (SaveBB == ResumeOrDestroyBB)
- return hasCallsInBlockBetween(Save->getNextNode(), ResumeOrDestroy);
- // Any calls from Save to the end of the block?
- if (hasCallsInBlockBetween(Save->getNextNode(), nullptr))
- return true;
- // Any calls from begging of the block up to ResumeOrDestroy?
- if (hasCallsInBlockBetween(ResumeOrDestroyBB->getFirstNonPHI(),
- ResumeOrDestroy))
- return true;
- // Any calls in all of the blocks between SaveBB and ResumeOrDestroyBB?
- if (hasCallsInBlocksBetween(SaveBB, ResumeOrDestroyBB))
- return true;
- return false;
- }
- // If a SuspendIntrin is preceded by Resume or Destroy, we can eliminate the
- // suspend point and replace it with nornal control flow.
- static bool simplifySuspendPoint(CoroSuspendInst *Suspend,
- CoroBeginInst *CoroBegin) {
- Instruction *Prev = Suspend->getPrevNode();
- if (!Prev) {
- auto *Pred = Suspend->getParent()->getSinglePredecessor();
- if (!Pred)
- return false;
- Prev = Pred->getTerminator();
- }
- CallSite CS{Prev};
- if (!CS)
- return false;
- auto *CallInstr = CS.getInstruction();
- auto *Callee = CS.getCalledValue()->stripPointerCasts();
- // See if the callsite is for resumption or destruction of the coroutine.
- auto *SubFn = dyn_cast<CoroSubFnInst>(Callee);
- if (!SubFn)
- return false;
- // Does not refer to the current coroutine, we cannot do anything with it.
- if (SubFn->getFrame() != CoroBegin)
- return false;
- // See if the transformation is safe. Specifically, see if there are any
- // calls in between Save and CallInstr. They can potenitally resume the
- // coroutine rendering this optimization unsafe.
- auto *Save = Suspend->getCoroSave();
- if (hasCallsBetween(Save, CallInstr))
- return false;
- // Replace llvm.coro.suspend with the value that results in resumption over
- // the resume or cleanup path.
- Suspend->replaceAllUsesWith(SubFn->getRawIndex());
- Suspend->eraseFromParent();
- Save->eraseFromParent();
- // No longer need a call to coro.resume or coro.destroy.
- if (auto *Invoke = dyn_cast<InvokeInst>(CallInstr)) {
- BranchInst::Create(Invoke->getNormalDest(), Invoke);
- }
- // Grab the CalledValue from CS before erasing the CallInstr.
- auto *CalledValue = CS.getCalledValue();
- CallInstr->eraseFromParent();
- // If no more users remove it. Usually it is a bitcast of SubFn.
- if (CalledValue != SubFn && CalledValue->user_empty())
- if (auto *I = dyn_cast<Instruction>(CalledValue))
- I->eraseFromParent();
- // Now we are good to remove SubFn.
- if (SubFn->user_empty())
- SubFn->eraseFromParent();
- return true;
- }
- // Remove suspend points that are simplified.
- static void simplifySuspendPoints(coro::Shape &Shape) {
- auto &S = Shape.CoroSuspends;
- size_t I = 0, N = S.size();
- if (N == 0)
- return;
- while (true) {
- if (simplifySuspendPoint(S[I], Shape.CoroBegin)) {
- if (--N == I)
- break;
- std::swap(S[I], S[N]);
- continue;
- }
- if (++I == N)
- break;
- }
- S.resize(N);
- }
- static SmallPtrSet<BasicBlock *, 4> getCoroBeginPredBlocks(CoroBeginInst *CB) {
- // Collect all blocks that we need to look for instructions to relocate.
- SmallPtrSet<BasicBlock *, 4> RelocBlocks;
- SmallVector<BasicBlock *, 4> Work;
- Work.push_back(CB->getParent());
- do {
- BasicBlock *Current = Work.pop_back_val();
- for (BasicBlock *BB : predecessors(Current))
- if (RelocBlocks.count(BB) == 0) {
- RelocBlocks.insert(BB);
- Work.push_back(BB);
- }
- } while (!Work.empty());
- return RelocBlocks;
- }
- static SmallPtrSet<Instruction *, 8>
- getNotRelocatableInstructions(CoroBeginInst *CoroBegin,
- SmallPtrSetImpl<BasicBlock *> &RelocBlocks) {
- SmallPtrSet<Instruction *, 8> DoNotRelocate;
- // Collect all instructions that we should not relocate
- SmallVector<Instruction *, 8> Work;
- // Start with CoroBegin and terminators of all preceding blocks.
- Work.push_back(CoroBegin);
- BasicBlock *CoroBeginBB = CoroBegin->getParent();
- for (BasicBlock *BB : RelocBlocks)
- if (BB != CoroBeginBB)
- Work.push_back(BB->getTerminator());
- // For every instruction in the Work list, place its operands in DoNotRelocate
- // set.
- do {
- Instruction *Current = Work.pop_back_val();
- LLVM_DEBUG(dbgs() << "CoroSplit: Will not relocate: " << *Current << "\n");
- DoNotRelocate.insert(Current);
- for (Value *U : Current->operands()) {
- auto *I = dyn_cast<Instruction>(U);
- if (!I)
- continue;
- if (auto *A = dyn_cast<AllocaInst>(I)) {
- // Stores to alloca instructions that occur before the coroutine frame
- // is allocated should not be moved; the stored values may be used by
- // the coroutine frame allocator. The operands to those stores must also
- // remain in place.
- for (const auto &User : A->users())
- if (auto *SI = dyn_cast<llvm::StoreInst>(User))
- if (RelocBlocks.count(SI->getParent()) != 0 &&
- DoNotRelocate.count(SI) == 0) {
- Work.push_back(SI);
- DoNotRelocate.insert(SI);
- }
- continue;
- }
- if (DoNotRelocate.count(I) == 0) {
- Work.push_back(I);
- DoNotRelocate.insert(I);
- }
- }
- } while (!Work.empty());
- return DoNotRelocate;
- }
- static void relocateInstructionBefore(CoroBeginInst *CoroBegin, Function &F) {
- // Analyze which non-alloca instructions are needed for allocation and
- // relocate the rest to after coro.begin. We need to do it, since some of the
- // targets of those instructions may be placed into coroutine frame memory
- // for which becomes available after coro.begin intrinsic.
- auto BlockSet = getCoroBeginPredBlocks(CoroBegin);
- auto DoNotRelocateSet = getNotRelocatableInstructions(CoroBegin, BlockSet);
- Instruction *InsertPt = CoroBegin->getNextNode();
- BasicBlock &BB = F.getEntryBlock(); // TODO: Look at other blocks as well.
- for (auto B = BB.begin(), E = BB.end(); B != E;) {
- Instruction &I = *B++;
- if (isa<AllocaInst>(&I))
- continue;
- if (&I == CoroBegin)
- break;
- if (DoNotRelocateSet.count(&I))
- continue;
- I.moveBefore(InsertPt);
- }
- }
- static void splitCoroutine(Function &F, CallGraph &CG, CallGraphSCC &SCC) {
- EliminateUnreachableBlocks(F);
- coro::Shape Shape(F);
- if (!Shape.CoroBegin)
- return;
- simplifySuspendPoints(Shape);
- relocateInstructionBefore(Shape.CoroBegin, F);
- buildCoroutineFrame(F, Shape);
- replaceFrameSize(Shape);
- // If there are no suspend points, no split required, just remove
- // the allocation and deallocation blocks, they are not needed.
- if (Shape.CoroSuspends.empty()) {
- handleNoSuspendCoroutine(Shape.CoroBegin, Shape.FrameTy);
- removeCoroEnds(Shape);
- postSplitCleanup(F);
- coro::updateCallGraph(F, {}, CG, SCC);
- return;
- }
- auto *ResumeEntry = createResumeEntryBlock(F, Shape);
- auto ResumeClone = createClone(F, ".resume", Shape, ResumeEntry, 0);
- auto DestroyClone = createClone(F, ".destroy", Shape, ResumeEntry, 1);
- auto CleanupClone = createClone(F, ".cleanup", Shape, ResumeEntry, 2);
- // We no longer need coro.end in F.
- removeCoroEnds(Shape);
- postSplitCleanup(F);
- postSplitCleanup(*ResumeClone);
- postSplitCleanup(*DestroyClone);
- postSplitCleanup(*CleanupClone);
- addMustTailToCoroResumes(*ResumeClone);
- // Store addresses resume/destroy/cleanup functions in the coroutine frame.
- updateCoroFrame(Shape, ResumeClone, DestroyClone, CleanupClone);
- // Create a constant array referring to resume/destroy/clone functions pointed
- // by the last argument of @llvm.coro.info, so that CoroElide pass can
- // determined correct function to call.
- setCoroInfo(F, Shape.CoroBegin, {ResumeClone, DestroyClone, CleanupClone});
- // Update call graph and add the functions we created to the SCC.
- coro::updateCallGraph(F, {ResumeClone, DestroyClone, CleanupClone}, CG, SCC);
- }
- // When we see the coroutine the first time, we insert an indirect call to a
- // devirt trigger function and mark the coroutine that it is now ready for
- // split.
- static void prepareForSplit(Function &F, CallGraph &CG) {
- Module &M = *F.getParent();
- LLVMContext &Context = F.getContext();
- #ifndef NDEBUG
- Function *DevirtFn = M.getFunction(CORO_DEVIRT_TRIGGER_FN);
- assert(DevirtFn && "coro.devirt.trigger function not found");
- #endif
- F.addFnAttr(CORO_PRESPLIT_ATTR, PREPARED_FOR_SPLIT);
- // Insert an indirect call sequence that will be devirtualized by CoroElide
- // pass:
- // %0 = call i8* @llvm.coro.subfn.addr(i8* null, i8 -1)
- // %1 = bitcast i8* %0 to void(i8*)*
- // call void %1(i8* null)
- coro::LowererBase Lowerer(M);
- Instruction *InsertPt = F.getEntryBlock().getTerminator();
- auto *Null = ConstantPointerNull::get(Type::getInt8PtrTy(Context));
- auto *DevirtFnAddr =
- Lowerer.makeSubFnCall(Null, CoroSubFnInst::RestartTrigger, InsertPt);
- FunctionType *FnTy = FunctionType::get(Type::getVoidTy(Context),
- {Type::getInt8PtrTy(Context)}, false);
- auto *IndirectCall = CallInst::Create(FnTy, DevirtFnAddr, Null, "", InsertPt);
- // Update CG graph with an indirect call we just added.
- CG[&F]->addCalledFunction(IndirectCall, CG.getCallsExternalNode());
- }
- // Make sure that there is a devirtualization trigger function that CoroSplit
- // pass uses the force restart CGSCC pipeline. If devirt trigger function is not
- // found, we will create one and add it to the current SCC.
- static void createDevirtTriggerFunc(CallGraph &CG, CallGraphSCC &SCC) {
- Module &M = CG.getModule();
- if (M.getFunction(CORO_DEVIRT_TRIGGER_FN))
- return;
- LLVMContext &C = M.getContext();
- auto *FnTy = FunctionType::get(Type::getVoidTy(C), Type::getInt8PtrTy(C),
- /*IsVarArgs=*/false);
- Function *DevirtFn =
- Function::Create(FnTy, GlobalValue::LinkageTypes::PrivateLinkage,
- CORO_DEVIRT_TRIGGER_FN, &M);
- DevirtFn->addFnAttr(Attribute::AlwaysInline);
- auto *Entry = BasicBlock::Create(C, "entry", DevirtFn);
- ReturnInst::Create(C, Entry);
- auto *Node = CG.getOrInsertFunction(DevirtFn);
- SmallVector<CallGraphNode *, 8> Nodes(SCC.begin(), SCC.end());
- Nodes.push_back(Node);
- SCC.initialize(Nodes);
- }
- //===----------------------------------------------------------------------===//
- // Top Level Driver
- //===----------------------------------------------------------------------===//
- namespace {
- struct CoroSplit : public CallGraphSCCPass {
- static char ID; // Pass identification, replacement for typeid
- CoroSplit() : CallGraphSCCPass(ID) {
- initializeCoroSplitPass(*PassRegistry::getPassRegistry());
- }
- bool Run = false;
- // A coroutine is identified by the presence of coro.begin intrinsic, if
- // we don't have any, this pass has nothing to do.
- bool doInitialization(CallGraph &CG) override {
- Run = coro::declaresIntrinsics(CG.getModule(), {"llvm.coro.begin"});
- return CallGraphSCCPass::doInitialization(CG);
- }
- bool runOnSCC(CallGraphSCC &SCC) override {
- if (!Run)
- return false;
- // Find coroutines for processing.
- SmallVector<Function *, 4> Coroutines;
- for (CallGraphNode *CGN : SCC)
- if (auto *F = CGN->getFunction())
- if (F->hasFnAttribute(CORO_PRESPLIT_ATTR))
- Coroutines.push_back(F);
- if (Coroutines.empty())
- return false;
- CallGraph &CG = getAnalysis<CallGraphWrapperPass>().getCallGraph();
- createDevirtTriggerFunc(CG, SCC);
- for (Function *F : Coroutines) {
- Attribute Attr = F->getFnAttribute(CORO_PRESPLIT_ATTR);
- StringRef Value = Attr.getValueAsString();
- LLVM_DEBUG(dbgs() << "CoroSplit: Processing coroutine '" << F->getName()
- << "' state: " << Value << "\n");
- if (Value == UNPREPARED_FOR_SPLIT) {
- prepareForSplit(*F, CG);
- continue;
- }
- F->removeFnAttr(CORO_PRESPLIT_ATTR);
- splitCoroutine(*F, CG, SCC);
- }
- return true;
- }
- void getAnalysisUsage(AnalysisUsage &AU) const override {
- CallGraphSCCPass::getAnalysisUsage(AU);
- }
- StringRef getPassName() const override { return "Coroutine Splitting"; }
- };
- } // end anonymous namespace
- char CoroSplit::ID = 0;
- INITIALIZE_PASS_BEGIN(
- CoroSplit, "coro-split",
- "Split coroutine into a set of functions driving its state machine", false,
- false)
- INITIALIZE_PASS_DEPENDENCY(CallGraphWrapperPass)
- INITIALIZE_PASS_END(
- CoroSplit, "coro-split",
- "Split coroutine into a set of functions driving its state machine", false,
- false)
- Pass *llvm::createCoroSplitPass() { return new CoroSplit(); }
|