StructRetPromotion.cpp 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365
  1. //===-- StructRetPromotion.cpp - Promote sret arguments ------------------===//
  2. //
  3. // The LLVM Compiler Infrastructure
  4. //
  5. // This file is distributed under the University of Illinois Open Source
  6. // License. See LICENSE.TXT for details.
  7. //
  8. //===----------------------------------------------------------------------===//
  9. //
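  // This pass finds functions that return a struct (using a pointer to the struct
  // as the first argument of the function, marked with the 'sret' attribute) and
  // replaces them with a new function that drops that argument and returns the
  // struct by value instead (i.e. returns all of the struct's elements at once,
  // using multiple return values).
  //
  // This pass works under a number of conditions:
  // 1. The returned struct must not contain other structs. (nestedStructType()
  //    implements this check, but it is not called anywhere in this file.)
  // 2. The returned struct must only be used to load values from: at each call
  //    site, every other use of it must be a GEP in the call's basic block whose
  //    only users are loads.
  // 3. The placeholder struct passed in is the result of an alloca.
  // 4. The function has local linkage, so that every caller can be rewritten.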
  19. //
  20. //===----------------------------------------------------------------------===//
  21. #define DEBUG_TYPE "sretpromotion"
  22. #include "llvm/Transforms/IPO.h"
  23. #include "llvm/Constants.h"
  24. #include "llvm/DerivedTypes.h"
  25. #include "llvm/LLVMContext.h"
  26. #include "llvm/Module.h"
  27. #include "llvm/CallGraphSCCPass.h"
  28. #include "llvm/Instructions.h"
  29. #include "llvm/Analysis/CallGraph.h"
  30. #include "llvm/Support/CallSite.h"
  31. #include "llvm/Support/CFG.h"
  32. #include "llvm/Support/Debug.h"
  33. #include "llvm/ADT/Statistic.h"
  34. #include "llvm/ADT/SmallVector.h"
  35. #include "llvm/ADT/Statistic.h"
  36. #include "llvm/Support/raw_ostream.h"
  37. using namespace llvm;
  38. STATISTIC(NumRejectedSRETUses , "Number of sret rejected due to unexpected uses");
  39. STATISTIC(NumSRET , "Number of sret promoted");
  40. namespace {
  41. /// SRETPromotion - This pass removes sret parameter and updates
  42. /// function to use multiple return value.
  43. ///
  44. struct SRETPromotion : public CallGraphSCCPass {
  45. virtual void getAnalysisUsage(AnalysisUsage &AU) const {
  46. CallGraphSCCPass::getAnalysisUsage(AU);
  47. }
  48. virtual bool runOnSCC(std::vector<CallGraphNode *> &SCC);
  49. static char ID; // Pass identification, replacement for typeid
  50. SRETPromotion() : CallGraphSCCPass(&ID) {}
  51. private:
  52. CallGraphNode *PromoteReturn(CallGraphNode *CGN);
  53. bool isSafeToUpdateAllCallers(Function *F);
  54. Function *cloneFunctionBody(Function *F, const StructType *STy);
  55. CallGraphNode *updateCallSites(Function *F, Function *NF);
  56. bool nestedStructType(const StructType *STy);
  57. };
  58. }
  59. char SRETPromotion::ID = 0;
  60. static RegisterPass<SRETPromotion>
  61. X("sretpromotion", "Promote sret arguments to multiple ret values");
  62. Pass *llvm::createStructRetPromotionPass() {
  63. return new SRETPromotion();
  64. }
  65. bool SRETPromotion::runOnSCC(std::vector<CallGraphNode *> &SCC) {
  66. bool Changed = false;
  67. for (unsigned i = 0, e = SCC.size(); i != e; ++i)
  68. if (CallGraphNode *NewNode = PromoteReturn(SCC[i])) {
  69. SCC[i] = NewNode;
  70. Changed = true;
  71. }
  72. return Changed;
  73. }
  74. /// PromoteReturn - This method promotes function that uses StructRet paramater
  75. /// into a function that uses multiple return values.
  76. CallGraphNode *SRETPromotion::PromoteReturn(CallGraphNode *CGN) {
  77. Function *F = CGN->getFunction();
  78. if (!F || F->isDeclaration() || !F->hasLocalLinkage())
  79. return 0;
  80. // Make sure that function returns struct.
  81. if (F->arg_size() == 0 || !F->hasStructRetAttr() || F->doesNotReturn())
  82. return 0;
  83. DEBUG(errs() << "SretPromotion: Looking at sret function "
  84. << F->getName() << "\n");
  85. assert(F->getReturnType() == Type::getVoidTy(F->getContext()) &&
  86. "Invalid function return type");
  87. Function::arg_iterator AI = F->arg_begin();
  88. const llvm::PointerType *FArgType = dyn_cast<PointerType>(AI->getType());
  89. assert(FArgType && "Invalid sret parameter type");
  90. const llvm::StructType *STy =
  91. dyn_cast<StructType>(FArgType->getElementType());
  92. assert(STy && "Invalid sret parameter element type");
  93. // Check if it is ok to perform this promotion.
  94. if (isSafeToUpdateAllCallers(F) == false) {
  95. DEBUG(errs() << "SretPromotion: Not all callers can be updated\n");
  96. NumRejectedSRETUses++;
  97. return 0;
  98. }
  99. DEBUG(errs() << "SretPromotion: sret argument will be promoted\n");
  100. NumSRET++;
  101. // [1] Replace use of sret parameter
  102. AllocaInst *TheAlloca = new AllocaInst(STy, NULL, "mrv",
  103. F->getEntryBlock().begin());
  104. Value *NFirstArg = F->arg_begin();
  105. NFirstArg->replaceAllUsesWith(TheAlloca);
  106. // [2] Find and replace ret instructions
  107. for (Function::iterator FI = F->begin(), FE = F->end(); FI != FE; ++FI)
  108. for(BasicBlock::iterator BI = FI->begin(), BE = FI->end(); BI != BE; ) {
  109. Instruction *I = BI;
  110. ++BI;
  111. if (isa<ReturnInst>(I)) {
  112. Value *NV = new LoadInst(TheAlloca, "mrv.ld", I);
  113. ReturnInst *NR = ReturnInst::Create(F->getContext(), NV, I);
  114. I->replaceAllUsesWith(NR);
  115. I->eraseFromParent();
  116. }
  117. }
  118. // [3] Create the new function body and insert it into the module.
  119. Function *NF = cloneFunctionBody(F, STy);
  120. // [4] Update all call sites to use new function
  121. CallGraphNode *NF_CFN = updateCallSites(F, NF);
  122. CallGraph &CG = getAnalysis<CallGraph>();
  123. NF_CFN->stealCalledFunctionsFrom(CG[F]);
  124. delete CG.removeFunctionFromModule(F);
  125. return NF_CFN;
  126. }
  127. // Check if it is ok to perform this promotion.
  128. bool SRETPromotion::isSafeToUpdateAllCallers(Function *F) {
  129. if (F->use_empty())
  130. // No users. OK to modify signature.
  131. return true;
  132. for (Value::use_iterator FnUseI = F->use_begin(), FnUseE = F->use_end();
  133. FnUseI != FnUseE; ++FnUseI) {
  134. // The function is passed in as an argument to (possibly) another function,
  135. // we can't change it!
  136. CallSite CS = CallSite::get(*FnUseI);
  137. Instruction *Call = CS.getInstruction();
  138. // The function is used by something else than a call or invoke instruction,
  139. // we can't change it!
  140. if (!Call || !CS.isCallee(FnUseI))
  141. return false;
  142. CallSite::arg_iterator AI = CS.arg_begin();
  143. Value *FirstArg = *AI;
  144. if (!isa<AllocaInst>(FirstArg))
  145. return false;
  146. // Check FirstArg's users.
  147. for (Value::use_iterator ArgI = FirstArg->use_begin(),
  148. ArgE = FirstArg->use_end(); ArgI != ArgE; ++ArgI) {
  149. // If FirstArg user is a CallInst that does not correspond to current
  150. // call site then this function F is not suitable for sret promotion.
  151. if (CallInst *CI = dyn_cast<CallInst>(ArgI)) {
  152. if (CI != Call)
  153. return false;
  154. }
  155. // If FirstArg user is a GEP whose all users are not LoadInst then
  156. // this function F is not suitable for sret promotion.
  157. else if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(ArgI)) {
  158. // TODO : Use dom info and insert PHINodes to collect get results
  159. // from multiple call sites for this GEP.
  160. if (GEP->getParent() != Call->getParent())
  161. return false;
  162. for (Value::use_iterator GEPI = GEP->use_begin(), GEPE = GEP->use_end();
  163. GEPI != GEPE; ++GEPI)
  164. if (!isa<LoadInst>(GEPI))
  165. return false;
  166. }
  167. // Any other FirstArg users make this function unsuitable for sret
  168. // promotion.
  169. else
  170. return false;
  171. }
  172. }
  173. return true;
  174. }
  175. /// cloneFunctionBody - Create a new function based on F and
  176. /// insert it into module. Remove first argument. Use STy as
  177. /// the return type for new function.
  178. Function *SRETPromotion::cloneFunctionBody(Function *F,
  179. const StructType *STy) {
  180. const FunctionType *FTy = F->getFunctionType();
  181. std::vector<const Type*> Params;
  182. // Attributes - Keep track of the parameter attributes for the arguments.
  183. SmallVector<AttributeWithIndex, 8> AttributesVec;
  184. const AttrListPtr &PAL = F->getAttributes();
  185. // Add any return attributes.
  186. if (Attributes attrs = PAL.getRetAttributes())
  187. AttributesVec.push_back(AttributeWithIndex::get(0, attrs));
  188. // Skip first argument.
  189. Function::arg_iterator I = F->arg_begin(), E = F->arg_end();
  190. ++I;
  191. // 0th parameter attribute is reserved for return type.
  192. // 1th parameter attribute is for first 1st sret argument.
  193. unsigned ParamIndex = 2;
  194. while (I != E) {
  195. Params.push_back(I->getType());
  196. if (Attributes Attrs = PAL.getParamAttributes(ParamIndex))
  197. AttributesVec.push_back(AttributeWithIndex::get(ParamIndex - 1, Attrs));
  198. ++I;
  199. ++ParamIndex;
  200. }
  201. // Add any fn attributes.
  202. if (Attributes attrs = PAL.getFnAttributes())
  203. AttributesVec.push_back(AttributeWithIndex::get(~0, attrs));
  204. FunctionType *NFTy = FunctionType::get(STy, Params, FTy->isVarArg());
  205. Function *NF = Function::Create(NFTy, F->getLinkage());
  206. NF->takeName(F);
  207. NF->copyAttributesFrom(F);
  208. NF->setAttributes(AttrListPtr::get(AttributesVec.begin(), AttributesVec.end()));
  209. F->getParent()->getFunctionList().insert(F, NF);
  210. NF->getBasicBlockList().splice(NF->begin(), F->getBasicBlockList());
  211. // Replace arguments
  212. I = F->arg_begin();
  213. E = F->arg_end();
  214. Function::arg_iterator NI = NF->arg_begin();
  215. ++I;
  216. while (I != E) {
  217. I->replaceAllUsesWith(NI);
  218. NI->takeName(I);
  219. ++I;
  220. ++NI;
  221. }
  222. return NF;
  223. }
  224. /// updateCallSites - Update all sites that call F to use NF.
  225. CallGraphNode *SRETPromotion::updateCallSites(Function *F, Function *NF) {
  226. CallGraph &CG = getAnalysis<CallGraph>();
  227. SmallVector<Value*, 16> Args;
  228. // Attributes - Keep track of the parameter attributes for the arguments.
  229. SmallVector<AttributeWithIndex, 8> ArgAttrsVec;
  230. // Get a new callgraph node for NF.
  231. CallGraphNode *NF_CGN = CG.getOrInsertFunction(NF);
  232. while (!F->use_empty()) {
  233. CallSite CS = CallSite::get(*F->use_begin());
  234. Instruction *Call = CS.getInstruction();
  235. const AttrListPtr &PAL = F->getAttributes();
  236. // Add any return attributes.
  237. if (Attributes attrs = PAL.getRetAttributes())
  238. ArgAttrsVec.push_back(AttributeWithIndex::get(0, attrs));
  239. // Copy arguments, however skip first one.
  240. CallSite::arg_iterator AI = CS.arg_begin(), AE = CS.arg_end();
  241. Value *FirstCArg = *AI;
  242. ++AI;
  243. // 0th parameter attribute is reserved for return type.
  244. // 1th parameter attribute is for first 1st sret argument.
  245. unsigned ParamIndex = 2;
  246. while (AI != AE) {
  247. Args.push_back(*AI);
  248. if (Attributes Attrs = PAL.getParamAttributes(ParamIndex))
  249. ArgAttrsVec.push_back(AttributeWithIndex::get(ParamIndex - 1, Attrs));
  250. ++ParamIndex;
  251. ++AI;
  252. }
  253. // Add any function attributes.
  254. if (Attributes attrs = PAL.getFnAttributes())
  255. ArgAttrsVec.push_back(AttributeWithIndex::get(~0, attrs));
  256. AttrListPtr NewPAL = AttrListPtr::get(ArgAttrsVec.begin(), ArgAttrsVec.end());
  257. // Build new call instruction.
  258. Instruction *New;
  259. if (InvokeInst *II = dyn_cast<InvokeInst>(Call)) {
  260. New = InvokeInst::Create(NF, II->getNormalDest(), II->getUnwindDest(),
  261. Args.begin(), Args.end(), "", Call);
  262. cast<InvokeInst>(New)->setCallingConv(CS.getCallingConv());
  263. cast<InvokeInst>(New)->setAttributes(NewPAL);
  264. } else {
  265. New = CallInst::Create(NF, Args.begin(), Args.end(), "", Call);
  266. cast<CallInst>(New)->setCallingConv(CS.getCallingConv());
  267. cast<CallInst>(New)->setAttributes(NewPAL);
  268. if (cast<CallInst>(Call)->isTailCall())
  269. cast<CallInst>(New)->setTailCall();
  270. }
  271. Args.clear();
  272. ArgAttrsVec.clear();
  273. New->takeName(Call);
  274. // Update the callgraph to know that the callsite has been transformed.
  275. CallGraphNode *CalleeNode = CG[Call->getParent()->getParent()];
  276. CalleeNode->removeCallEdgeFor(Call);
  277. CalleeNode->addCalledFunction(New, NF_CGN);
  278. // Update all users of sret parameter to extract value using extractvalue.
  279. for (Value::use_iterator UI = FirstCArg->use_begin(),
  280. UE = FirstCArg->use_end(); UI != UE; ) {
  281. User *U2 = *UI++;
  282. CallInst *C2 = dyn_cast<CallInst>(U2);
  283. if (C2 && (C2 == Call))
  284. continue;
  285. GetElementPtrInst *UGEP = cast<GetElementPtrInst>(U2);
  286. ConstantInt *Idx = cast<ConstantInt>(UGEP->getOperand(2));
  287. Value *GR = ExtractValueInst::Create(New, Idx->getZExtValue(),
  288. "evi", UGEP);
  289. while(!UGEP->use_empty()) {
  290. // isSafeToUpdateAllCallers has checked that all GEP uses are
  291. // LoadInsts
  292. LoadInst *L = cast<LoadInst>(*UGEP->use_begin());
  293. L->replaceAllUsesWith(GR);
  294. L->eraseFromParent();
  295. }
  296. UGEP->eraseFromParent();
  297. continue;
  298. }
  299. Call->eraseFromParent();
  300. }
  301. return NF_CGN;
  302. }
  303. /// nestedStructType - Return true if STy includes any
  304. /// other aggregate types
  305. bool SRETPromotion::nestedStructType(const StructType *STy) {
  306. unsigned Num = STy->getNumElements();
  307. for (unsigned i = 0; i < Num; i++) {
  308. const Type *Ty = STy->getElementType(i);
  309. if (!Ty->isSingleValueType() && Ty != Type::getVoidTy(STy->getContext()))
  310. return true;
  311. }
  312. return false;
  313. }