JumpInstrTables.cpp 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296
  1. //===-- JumpInstrTables.cpp: Jump-Instruction Tables ----------------------===//
  2. //
  3. // This file is distributed under the University of Illinois Open Source
  4. // License. See LICENSE.TXT for details.
  5. //
  6. //===----------------------------------------------------------------------===//
  7. ///
  8. /// \file
  9. /// \brief An implementation of jump-instruction tables.
  10. ///
  11. //===----------------------------------------------------------------------===//
  12. #define DEBUG_TYPE "jt"
  13. #include "llvm/CodeGen/JumpInstrTables.h"
  14. #include "llvm/ADT/Statistic.h"
  15. #include "llvm/Analysis/JumpInstrTableInfo.h"
  16. #include "llvm/CodeGen/Passes.h"
  17. #include "llvm/IR/Attributes.h"
  18. #include "llvm/IR/CallSite.h"
  19. #include "llvm/IR/Constants.h"
  20. #include "llvm/IR/DerivedTypes.h"
  21. #include "llvm/IR/Function.h"
  22. #include "llvm/IR/LLVMContext.h"
  23. #include "llvm/IR/Module.h"
  24. #include "llvm/IR/Operator.h"
  25. #include "llvm/IR/Type.h"
  26. #include "llvm/IR/Verifier.h"
  27. #include "llvm/Support/CommandLine.h"
  28. #include "llvm/Support/Debug.h"
  29. #include "llvm/Support/raw_ostream.h"
  30. #include <vector>
  31. using namespace llvm;
  32. char JumpInstrTables::ID = 0;
  33. INITIALIZE_PASS_BEGIN(JumpInstrTables, "jump-instr-tables",
  34. "Jump-Instruction Tables", true, true)
  35. INITIALIZE_PASS_DEPENDENCY(JumpInstrTableInfo);
  36. INITIALIZE_PASS_END(JumpInstrTables, "jump-instr-tables",
  37. "Jump-Instruction Tables", true, true)
  38. STATISTIC(NumJumpTables, "Number of indirect call tables generated");
  39. STATISTIC(NumFuncsInJumpTables, "Number of functions in the jump tables");
  40. ModulePass *llvm::createJumpInstrTablesPass() {
  41. // The default implementation uses a single table for all functions.
  42. return new JumpInstrTables(JumpTable::Single);
  43. }
  44. ModulePass *llvm::createJumpInstrTablesPass(JumpTable::JumpTableType JTT) {
  45. return new JumpInstrTables(JTT);
  46. }
  47. namespace {
  48. static const char jump_func_prefix[] = "__llvm_jump_instr_table_";
  49. static const char jump_section_prefix[] = ".jump.instr.table.text.";
  50. // Checks to see if a given CallSite is making an indirect call, including
  51. // cases where the indirect call is made through a bitcast.
  52. bool isIndirectCall(CallSite &CS) {
  53. if (CS.getCalledFunction())
  54. return false;
  55. // Check the value to see if it is merely a bitcast of a function. In
  56. // this case, it will translate to a direct function call in the resulting
  57. // assembly, so we won't treat it as an indirect call here.
  58. const Value *V = CS.getCalledValue();
  59. if (const ConstantExpr *CE = dyn_cast<ConstantExpr>(V)) {
  60. return !(CE->isCast() && isa<Function>(CE->getOperand(0)));
  61. }
  62. // Otherwise, since we know it's a call, it must be an indirect call
  63. return true;
  64. }
  65. // Replaces Functions and GlobalAliases with a different Value.
  66. bool replaceGlobalValueIndirectUse(GlobalValue *GV, Value *V, Use *U) {
  67. User *Us = U->getUser();
  68. if (!Us)
  69. return false;
  70. if (Instruction *I = dyn_cast<Instruction>(Us)) {
  71. CallSite CS(I);
  72. // Don't do the replacement if this use is a direct call to this function.
  73. // If the use is not the called value, then replace it.
  74. if (CS && (isIndirectCall(CS) || CS.isCallee(U))) {
  75. return false;
  76. }
  77. U->set(V);
  78. } else if (Constant *C = dyn_cast<Constant>(Us)) {
  79. // Don't replace calls to bitcasts of function symbols, since they get
  80. // translated to direct calls.
  81. if (ConstantExpr *CE = dyn_cast<ConstantExpr>(Us)) {
  82. if (CE->getOpcode() == Instruction::BitCast) {
  83. // This bitcast must have exactly one user.
  84. if (CE->user_begin() != CE->user_end()) {
  85. User *ParentUs = *CE->user_begin();
  86. if (CallInst *CI = dyn_cast<CallInst>(ParentUs)) {
  87. CallSite CS(CI);
  88. Use &CEU = *CE->use_begin();
  89. if (CS.isCallee(&CEU)) {
  90. return false;
  91. }
  92. }
  93. }
  94. }
  95. }
  96. // GlobalAlias doesn't support replaceUsesOfWithOnConstant. And the verifier
  97. // requires alias to point to a defined function. So, GlobalAlias is handled
  98. // as a separate case in runOnModule.
  99. if (!isa<GlobalAlias>(C))
  100. C->replaceUsesOfWithOnConstant(GV, V, U);
  101. } else {
  102. llvm_unreachable("The Use of a Function symbol is neither an instruction "
  103. "nor a constant");
  104. }
  105. return true;
  106. }
  107. // Replaces all replaceable address-taken uses of GV with a pointer to a
  108. // jump-instruction table entry.
  109. void replaceValueWithFunction(GlobalValue *GV, Function *F) {
  110. // Go through all uses of this function and replace the uses of GV with the
  111. // jump-table version of the function. Get the uses as a vector before
  112. // replacing them, since replacing them changes the use list and invalidates
  113. // the iterator otherwise.
  114. for (Value::use_iterator I = GV->use_begin(), E = GV->use_end(); I != E;) {
  115. Use &U = *I++;
  116. // Replacement of constants replaces all instances in the constant. So, some
  117. // uses might have already been handled by the time we reach them here.
  118. if (U.get() == GV)
  119. replaceGlobalValueIndirectUse(GV, F, &U);
  120. }
  121. return;
  122. }
  123. } // end anonymous namespace
  124. JumpInstrTables::JumpInstrTables()
  125. : ModulePass(ID), Metadata(), JITI(nullptr), TableCount(0),
  126. JTType(JumpTable::Single) {
  127. initializeJumpInstrTablesPass(*PassRegistry::getPassRegistry());
  128. }
  129. JumpInstrTables::JumpInstrTables(JumpTable::JumpTableType JTT)
  130. : ModulePass(ID), Metadata(), JITI(nullptr), TableCount(0), JTType(JTT) {
  131. initializeJumpInstrTablesPass(*PassRegistry::getPassRegistry());
  132. }
  133. JumpInstrTables::~JumpInstrTables() {}
  134. void JumpInstrTables::getAnalysisUsage(AnalysisUsage &AU) const {
  135. AU.addRequired<JumpInstrTableInfo>();
  136. }
  137. Function *JumpInstrTables::insertEntry(Module &M, Function *Target) {
  138. FunctionType *OrigFunTy = Target->getFunctionType();
  139. FunctionType *FunTy = transformType(JTType, OrigFunTy);
  140. JumpMap::iterator it = Metadata.find(FunTy);
  141. if (Metadata.end() == it) {
  142. struct TableMeta Meta;
  143. Meta.TableNum = TableCount;
  144. Meta.Count = 0;
  145. Metadata[FunTy] = Meta;
  146. it = Metadata.find(FunTy);
  147. ++NumJumpTables;
  148. ++TableCount;
  149. }
  150. it->second.Count++;
  151. std::string NewName(jump_func_prefix);
  152. NewName += (Twine(it->second.TableNum) + "_" + Twine(it->second.Count)).str();
  153. Function *JumpFun =
  154. Function::Create(OrigFunTy, GlobalValue::ExternalLinkage, NewName, &M);
  155. // The section for this table
  156. JumpFun->setSection((jump_section_prefix + Twine(it->second.TableNum)).str());
  157. JITI->insertEntry(FunTy, Target, JumpFun);
  158. ++NumFuncsInJumpTables;
  159. return JumpFun;
  160. }
  161. bool JumpInstrTables::hasTable(FunctionType *FunTy) {
  162. FunctionType *TransTy = transformType(JTType, FunTy);
  163. return Metadata.end() != Metadata.find(TransTy);
  164. }
  165. FunctionType *JumpInstrTables::transformType(JumpTable::JumpTableType JTT,
  166. FunctionType *FunTy) {
  167. // Returning nullptr forces all types into the same table, since all types map
  168. // to the same type
  169. Type *VoidPtrTy = Type::getInt8PtrTy(FunTy->getContext());
  170. // Ignore the return type.
  171. Type *RetTy = VoidPtrTy;
  172. bool IsVarArg = FunTy->isVarArg();
  173. std::vector<Type *> ParamTys(FunTy->getNumParams());
  174. FunctionType::param_iterator PI, PE;
  175. int i = 0;
  176. std::vector<Type *> EmptyParams;
  177. Type *Int32Ty = Type::getInt32Ty(FunTy->getContext());
  178. FunctionType *VoidFnTy = FunctionType::get(
  179. Type::getVoidTy(FunTy->getContext()), EmptyParams, false);
  180. switch (JTT) {
  181. case JumpTable::Single:
  182. return FunctionType::get(RetTy, EmptyParams, false);
  183. case JumpTable::Arity:
  184. // Transform all types to void* so that all functions with the same arity
  185. // end up in the same table.
  186. for (PI = FunTy->param_begin(), PE = FunTy->param_end(); PI != PE;
  187. PI++, i++) {
  188. ParamTys[i] = VoidPtrTy;
  189. }
  190. return FunctionType::get(RetTy, ParamTys, IsVarArg);
  191. case JumpTable::Simplified:
  192. // Project all parameters types to one of 3 types: composite, integer, and
  193. // function, matching the three subclasses of Type.
  194. for (PI = FunTy->param_begin(), PE = FunTy->param_end(); PI != PE;
  195. ++PI, ++i) {
  196. assert((isa<IntegerType>(*PI) || isa<FunctionType>(*PI) ||
  197. isa<CompositeType>(*PI)) &&
  198. "This type is not an Integer or a Composite or a Function");
  199. if (isa<CompositeType>(*PI)) {
  200. ParamTys[i] = VoidPtrTy;
  201. } else if (isa<FunctionType>(*PI)) {
  202. ParamTys[i] = VoidFnTy;
  203. } else if (isa<IntegerType>(*PI)) {
  204. ParamTys[i] = Int32Ty;
  205. }
  206. }
  207. return FunctionType::get(RetTy, ParamTys, IsVarArg);
  208. case JumpTable::Full:
  209. // Don't transform this type at all.
  210. return FunTy;
  211. }
  212. return nullptr;
  213. }
  214. bool JumpInstrTables::runOnModule(Module &M) {
  215. JITI = &getAnalysis<JumpInstrTableInfo>();
  216. // Get the set of jumptable-annotated functions that have their address taken.
  217. DenseMap<Function *, Function *> Functions;
  218. for (Function &F : M) {
  219. if (F.hasFnAttribute(Attribute::JumpTable) && F.hasAddressTaken()) {
  220. assert(F.hasUnnamedAddr() &&
  221. "Attribute 'jumptable' requires 'unnamed_addr'");
  222. Functions[&F] = nullptr;
  223. }
  224. }
  225. // Create the jump-table functions.
  226. for (auto &KV : Functions) {
  227. Function *F = KV.first;
  228. KV.second = insertEntry(M, F);
  229. }
  230. // GlobalAlias is a special case, because the target of an alias statement
  231. // must be a defined function. So, instead of replacing a given function in
  232. // the alias, we replace all uses of aliases that target jumptable functions.
  233. // Note that there's no need to create these functions, since only aliases
  234. // that target known jumptable functions are replaced, and there's no way to
  235. // put the jumptable annotation on a global alias.
  236. DenseMap<GlobalAlias *, Function *> Aliases;
  237. for (GlobalAlias &GA : M.aliases()) {
  238. Constant *Aliasee = GA.getAliasee();
  239. if (Function *F = dyn_cast<Function>(Aliasee)) {
  240. auto it = Functions.find(F);
  241. if (it != Functions.end()) {
  242. Aliases[&GA] = it->second;
  243. }
  244. }
  245. }
  246. // Replace each address taken function with its jump-instruction table entry.
  247. for (auto &KV : Functions)
  248. replaceValueWithFunction(KV.first, KV.second);
  249. for (auto &KV : Aliases)
  250. replaceValueWithFunction(KV.first, KV.second);
  251. return !Functions.empty();
  252. }