PointerTracking.cpp 8.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262
  1. //===- PointerTracking.cpp - Pointer Bounds Tracking ------------*- C++ -*-===//
  2. //
  3. // The LLVM Compiler Infrastructure
  4. //
  5. // This file is distributed under the University of Illinois Open Source
  6. // License. See LICENSE.TXT for details.
  7. //
  8. //===----------------------------------------------------------------------===//
  9. //
  10. // This file implements tracking of pointer bounds.
  11. //
  12. //===----------------------------------------------------------------------===//
  13. #include "llvm/Analysis/ConstantFolding.h"
  14. #include "llvm/Analysis/Dominators.h"
  15. #include "llvm/Analysis/LoopInfo.h"
  16. #include "llvm/Analysis/PointerTracking.h"
  17. #include "llvm/Analysis/ScalarEvolution.h"
  18. #include "llvm/Analysis/ScalarEvolutionExpressions.h"
  19. #include "llvm/Constants.h"
  20. #include "llvm/Module.h"
  21. #include "llvm/Value.h"
  22. #include "llvm/Support/CallSite.h"
  23. #include "llvm/Support/InstIterator.h"
  24. #include "llvm/Support/raw_ostream.h"
  25. #include "llvm/Target/TargetData.h"
  26. namespace llvm {
  27. char PointerTracking::ID=0;
  28. PointerTracking::PointerTracking() : FunctionPass(&ID) {}
  29. bool PointerTracking::runOnFunction(Function &F) {
  30. predCache.clear();
  31. assert(analyzing.empty());
  32. FF = &F;
  33. TD = getAnalysisIfAvailable<TargetData>();
  34. SE = &getAnalysis<ScalarEvolution>();
  35. LI = &getAnalysis<LoopInfo>();
  36. DT = &getAnalysis<DominatorTree>();
  37. return false;
  38. }
  39. void PointerTracking::getAnalysisUsage(AnalysisUsage &AU) const {
  40. AU.addRequiredTransitive<DominatorTree>();
  41. AU.addRequiredTransitive<LoopInfo>();
  42. AU.addRequiredTransitive<ScalarEvolution>();
  43. AU.setPreservesAll();
  44. }
  45. bool PointerTracking::doInitialization(Module &M) {
  46. const Type *PTy = PointerType::getUnqual(Type::getInt8Ty(M.getContext()));
  47. // Find calloc(i64, i64) or calloc(i32, i32).
  48. callocFunc = M.getFunction("calloc");
  49. if (callocFunc) {
  50. const FunctionType *Ty = callocFunc->getFunctionType();
  51. std::vector<const Type*> args, args2;
  52. args.push_back(Type::getInt64Ty(M.getContext()));
  53. args.push_back(Type::getInt64Ty(M.getContext()));
  54. args2.push_back(Type::getInt32Ty(M.getContext()));
  55. args2.push_back(Type::getInt32Ty(M.getContext()));
  56. const FunctionType *Calloc1Type =
  57. FunctionType::get(PTy, args, false);
  58. const FunctionType *Calloc2Type =
  59. FunctionType::get(PTy, args2, false);
  60. if (Ty != Calloc1Type && Ty != Calloc2Type)
  61. callocFunc = 0; // Give up
  62. }
  63. // Find realloc(i8*, i64) or realloc(i8*, i32).
  64. reallocFunc = M.getFunction("realloc");
  65. if (reallocFunc) {
  66. const FunctionType *Ty = reallocFunc->getFunctionType();
  67. std::vector<const Type*> args, args2;
  68. args.push_back(PTy);
  69. args.push_back(Type::getInt64Ty(M.getContext()));
  70. args2.push_back(PTy);
  71. args2.push_back(Type::getInt32Ty(M.getContext()));
  72. const FunctionType *Realloc1Type =
  73. FunctionType::get(PTy, args, false);
  74. const FunctionType *Realloc2Type =
  75. FunctionType::get(PTy, args2, false);
  76. if (Ty != Realloc1Type && Ty != Realloc2Type)
  77. reallocFunc = 0; // Give up
  78. }
  79. return false;
  80. }
  81. // Calculates the number of elements allocated for pointer P,
  82. // the type of the element is stored in Ty.
  83. const SCEV *PointerTracking::computeAllocationCount(Value *P,
  84. const Type *&Ty) const {
  85. Value *V = P->stripPointerCasts();
  86. if (AllocationInst *AI = dyn_cast<AllocationInst>(V)) {
  87. Value *arraySize = AI->getArraySize();
  88. Ty = AI->getAllocatedType();
  89. // arraySize elements of type Ty.
  90. return SE->getSCEV(arraySize);
  91. }
  92. if (GlobalVariable *GV = dyn_cast<GlobalVariable>(V)) {
  93. if (GV->hasDefinitiveInitializer()) {
  94. Constant *C = GV->getInitializer();
  95. if (const ArrayType *ATy = dyn_cast<ArrayType>(C->getType())) {
  96. Ty = ATy->getElementType();
  97. return SE->getConstant(Type::getInt32Ty(Ty->getContext()),
  98. ATy->getNumElements());
  99. }
  100. }
  101. Ty = GV->getType();
  102. return SE->getConstant(Type::getInt32Ty(Ty->getContext()), 1);
  103. //TODO: implement more tracking for globals
  104. }
  105. if (CallInst *CI = dyn_cast<CallInst>(V)) {
  106. CallSite CS(CI);
  107. Function *F = dyn_cast<Function>(CS.getCalledValue()->stripPointerCasts());
  108. const Loop *L = LI->getLoopFor(CI->getParent());
  109. if (F == callocFunc) {
  110. Ty = Type::getInt8Ty(Ty->getContext());
  111. // calloc allocates arg0*arg1 bytes.
  112. return SE->getSCEVAtScope(SE->getMulExpr(SE->getSCEV(CS.getArgument(0)),
  113. SE->getSCEV(CS.getArgument(1))),
  114. L);
  115. } else if (F == reallocFunc) {
  116. Ty = Type::getInt8Ty(Ty->getContext());
  117. // realloc allocates arg1 bytes.
  118. return SE->getSCEVAtScope(CS.getArgument(1), L);
  119. }
  120. }
  121. return SE->getCouldNotCompute();
  122. }
  123. // Calculates the number of elements of type Ty allocated for P.
  124. const SCEV *PointerTracking::computeAllocationCountForType(Value *P,
  125. const Type *Ty)
  126. const {
  127. const Type *elementTy;
  128. const SCEV *Count = computeAllocationCount(P, elementTy);
  129. if (isa<SCEVCouldNotCompute>(Count))
  130. return Count;
  131. if (elementTy == Ty)
  132. return Count;
  133. if (!TD) // need TargetData from this point forward
  134. return SE->getCouldNotCompute();
  135. uint64_t elementSize = TD->getTypeAllocSize(elementTy);
  136. uint64_t wantSize = TD->getTypeAllocSize(Ty);
  137. if (elementSize == wantSize)
  138. return Count;
  139. if (elementSize % wantSize) //fractional counts not possible
  140. return SE->getCouldNotCompute();
  141. return SE->getMulExpr(Count, SE->getConstant(Count->getType(),
  142. elementSize/wantSize));
  143. }
  144. const SCEV *PointerTracking::getAllocationElementCount(Value *V) const {
  145. // We only deal with pointers.
  146. const PointerType *PTy = cast<PointerType>(V->getType());
  147. return computeAllocationCountForType(V, PTy->getElementType());
  148. }
  149. const SCEV *PointerTracking::getAllocationSizeInBytes(Value *V) const {
  150. return computeAllocationCountForType(V, Type::getInt8Ty(V->getContext()));
  151. }
  152. // Helper for isLoopGuardedBy that checks the swapped and inverted predicate too
  153. enum SolverResult PointerTracking::isLoopGuardedBy(const Loop *L,
  154. Predicate Pred,
  155. const SCEV *A,
  156. const SCEV *B) const {
  157. if (SE->isLoopGuardedByCond(L, Pred, A, B))
  158. return AlwaysTrue;
  159. Pred = ICmpInst::getSwappedPredicate(Pred);
  160. if (SE->isLoopGuardedByCond(L, Pred, B, A))
  161. return AlwaysTrue;
  162. Pred = ICmpInst::getInversePredicate(Pred);
  163. if (SE->isLoopGuardedByCond(L, Pred, B, A))
  164. return AlwaysFalse;
  165. Pred = ICmpInst::getSwappedPredicate(Pred);
  166. if (SE->isLoopGuardedByCond(L, Pred, A, B))
  167. return AlwaysTrue;
  168. return Unknown;
  169. }
  170. enum SolverResult PointerTracking::checkLimits(const SCEV *Offset,
  171. const SCEV *Limit,
  172. BasicBlock *BB)
  173. {
  174. //FIXME: merge implementation
  175. return Unknown;
  176. }
  177. void PointerTracking::getPointerOffset(Value *Pointer, Value *&Base,
  178. const SCEV *&Limit,
  179. const SCEV *&Offset) const
  180. {
  181. Pointer = Pointer->stripPointerCasts();
  182. Base = Pointer->getUnderlyingObject();
  183. Limit = getAllocationSizeInBytes(Base);
  184. if (isa<SCEVCouldNotCompute>(Limit)) {
  185. Base = 0;
  186. Offset = Limit;
  187. return;
  188. }
  189. Offset = SE->getMinusSCEV(SE->getSCEV(Pointer), SE->getSCEV(Base));
  190. if (isa<SCEVCouldNotCompute>(Offset)) {
  191. Base = 0;
  192. Limit = Offset;
  193. }
  194. }
  195. void PointerTracking::print(raw_ostream &OS, const Module* M) const {
  196. // Calling some PT methods may cause caches to be updated, however
  197. // this should be safe for the same reason its safe for SCEV.
  198. PointerTracking &PT = *const_cast<PointerTracking*>(this);
  199. for (inst_iterator I=inst_begin(*FF), E=inst_end(*FF); I != E; ++I) {
  200. if (!isa<PointerType>(I->getType()))
  201. continue;
  202. Value *Base;
  203. const SCEV *Limit, *Offset;
  204. getPointerOffset(&*I, Base, Limit, Offset);
  205. if (!Base)
  206. continue;
  207. if (Base == &*I) {
  208. const SCEV *S = getAllocationElementCount(Base);
  209. OS << *Base << " ==> " << *S << " elements, ";
  210. OS << *Limit << " bytes allocated\n";
  211. continue;
  212. }
  213. OS << &*I << " -- base: " << *Base;
  214. OS << " offset: " << *Offset;
  215. enum SolverResult res = PT.checkLimits(Offset, Limit, I->getParent());
  216. switch (res) {
  217. case AlwaysTrue:
  218. OS << " always safe\n";
  219. break;
  220. case AlwaysFalse:
  221. OS << " always unsafe\n";
  222. break;
  223. case Unknown:
  224. OS << " <<unknown>>\n";
  225. break;
  226. }
  227. }
  228. }
  229. void PointerTracking::print(std::ostream &o, const Module* M) const {
  230. raw_os_ostream OS(o);
  231. print(OS, M);
  232. }
  233. static RegisterPass<PointerTracking> X("pointertracking",
  234. "Track pointer bounds", false, true);
  235. }