LoopUnrolling.cpp 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295
  1. //===--- LoopUnrolling.cpp - Unroll loops -----------------------*- C++ -*-===//
  2. //
  3. // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
  4. // See https://llvm.org/LICENSE.txt for license information.
  5. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
  6. //
  7. //===----------------------------------------------------------------------===//
  8. ///
/// This file contains functions which are used to decide whether a loop is
/// worth unrolling. Moreover, these functions manage the stack of loops which
/// is tracked by the ProgramState.
  12. ///
  13. //===----------------------------------------------------------------------===//
  14. #include "clang/ASTMatchers/ASTMatchers.h"
  15. #include "clang/ASTMatchers/ASTMatchFinder.h"
  16. #include "clang/StaticAnalyzer/Core/PathSensitive/CallEvent.h"
  17. #include "clang/StaticAnalyzer/Core/PathSensitive/CheckerContext.h"
  18. #include "clang/StaticAnalyzer/Core/PathSensitive/LoopUnrolling.h"
  19. using namespace clang;
  20. using namespace ento;
  21. using namespace clang::ast_matchers;
  22. static const int MAXIMUM_STEP_UNROLLED = 128;
  23. struct LoopState {
  24. private:
  25. enum Kind { Normal, Unrolled } K;
  26. const Stmt *LoopStmt;
  27. const LocationContext *LCtx;
  28. unsigned maxStep;
  29. LoopState(Kind InK, const Stmt *S, const LocationContext *L, unsigned N)
  30. : K(InK), LoopStmt(S), LCtx(L), maxStep(N) {}
  31. public:
  32. static LoopState getNormal(const Stmt *S, const LocationContext *L,
  33. unsigned N) {
  34. return LoopState(Normal, S, L, N);
  35. }
  36. static LoopState getUnrolled(const Stmt *S, const LocationContext *L,
  37. unsigned N) {
  38. return LoopState(Unrolled, S, L, N);
  39. }
  40. bool isUnrolled() const { return K == Unrolled; }
  41. unsigned getMaxStep() const { return maxStep; }
  42. const Stmt *getLoopStmt() const { return LoopStmt; }
  43. const LocationContext *getLocationContext() const { return LCtx; }
  44. bool operator==(const LoopState &X) const {
  45. return K == X.K && LoopStmt == X.LoopStmt;
  46. }
  47. void Profile(llvm::FoldingSetNodeID &ID) const {
  48. ID.AddInteger(K);
  49. ID.AddPointer(LoopStmt);
  50. ID.AddPointer(LCtx);
  51. ID.AddInteger(maxStep);
  52. }
  53. };
// The tracked stack of loops. The stack records which loops enclose the
// element currently being simulated. Each loop is marked according to whether
// we decided to unroll it.
// TODO: The loop stack should not need to be in the program state since it is
// lexical in nature. Instead, the stack of loops should be tracked in the
// LocationContext.
REGISTER_LIST_WITH_PROGRAMSTATE(LoopStack, LoopState)
  61. namespace clang {
  62. namespace ento {
  63. static bool isLoopStmt(const Stmt *S) {
  64. return S && (isa<ForStmt>(S) || isa<WhileStmt>(S) || isa<DoStmt>(S));
  65. }
  66. ProgramStateRef processLoopEnd(const Stmt *LoopStmt, ProgramStateRef State) {
  67. auto LS = State->get<LoopStack>();
  68. if (!LS.isEmpty() && LS.getHead().getLoopStmt() == LoopStmt)
  69. State = State->set<LoopStack>(LS.getTail());
  70. return State;
  71. }
  72. static internal::Matcher<Stmt> simpleCondition(StringRef BindName) {
  73. return binaryOperator(anyOf(hasOperatorName("<"), hasOperatorName(">"),
  74. hasOperatorName("<="), hasOperatorName(">="),
  75. hasOperatorName("!=")),
  76. hasEitherOperand(ignoringParenImpCasts(declRefExpr(
  77. to(varDecl(hasType(isInteger())).bind(BindName))))),
  78. hasEitherOperand(ignoringParenImpCasts(
  79. integerLiteral().bind("boundNum"))))
  80. .bind("conditionOperator");
  81. }
  82. static internal::Matcher<Stmt>
  83. changeIntBoundNode(internal::Matcher<Decl> VarNodeMatcher) {
  84. return anyOf(
  85. unaryOperator(anyOf(hasOperatorName("--"), hasOperatorName("++")),
  86. hasUnaryOperand(ignoringParenImpCasts(
  87. declRefExpr(to(varDecl(VarNodeMatcher)))))),
  88. binaryOperator(isAssignmentOperator(),
  89. hasLHS(ignoringParenImpCasts(
  90. declRefExpr(to(varDecl(VarNodeMatcher)))))));
  91. }
  92. static internal::Matcher<Stmt>
  93. callByRef(internal::Matcher<Decl> VarNodeMatcher) {
  94. return callExpr(forEachArgumentWithParam(
  95. declRefExpr(to(varDecl(VarNodeMatcher))),
  96. parmVarDecl(hasType(references(qualType(unless(isConstQualified())))))));
  97. }
  98. static internal::Matcher<Stmt>
  99. assignedToRef(internal::Matcher<Decl> VarNodeMatcher) {
  100. return declStmt(hasDescendant(varDecl(
  101. allOf(hasType(referenceType()),
  102. hasInitializer(anyOf(
  103. initListExpr(has(declRefExpr(to(varDecl(VarNodeMatcher))))),
  104. declRefExpr(to(varDecl(VarNodeMatcher)))))))));
  105. }
  106. static internal::Matcher<Stmt>
  107. getAddrTo(internal::Matcher<Decl> VarNodeMatcher) {
  108. return unaryOperator(
  109. hasOperatorName("&"),
  110. hasUnaryOperand(declRefExpr(hasDeclaration(VarNodeMatcher))));
  111. }
  112. static internal::Matcher<Stmt> hasSuspiciousStmt(StringRef NodeName) {
  113. return hasDescendant(stmt(
  114. anyOf(gotoStmt(), switchStmt(), returnStmt(),
  115. // Escaping and not known mutation of the loop counter is handled
  116. // by exclusion of assigning and address-of operators and
  117. // pass-by-ref function calls on the loop counter from the body.
  118. changeIntBoundNode(equalsBoundNode(NodeName)),
  119. callByRef(equalsBoundNode(NodeName)),
  120. getAddrTo(equalsBoundNode(NodeName)),
  121. assignedToRef(equalsBoundNode(NodeName)))));
  122. }
  123. static internal::Matcher<Stmt> forLoopMatcher() {
  124. return forStmt(
  125. hasCondition(simpleCondition("initVarName")),
  126. // Initialization should match the form: 'int i = 6' or 'i = 42'.
  127. hasLoopInit(
  128. anyOf(declStmt(hasSingleDecl(
  129. varDecl(allOf(hasInitializer(ignoringParenImpCasts(
  130. integerLiteral().bind("initNum"))),
  131. equalsBoundNode("initVarName"))))),
  132. binaryOperator(hasLHS(declRefExpr(to(varDecl(
  133. equalsBoundNode("initVarName"))))),
  134. hasRHS(ignoringParenImpCasts(
  135. integerLiteral().bind("initNum")))))),
  136. // Incrementation should be a simple increment or decrement
  137. // operator call.
  138. hasIncrement(unaryOperator(
  139. anyOf(hasOperatorName("++"), hasOperatorName("--")),
  140. hasUnaryOperand(declRefExpr(
  141. to(varDecl(allOf(equalsBoundNode("initVarName"),
  142. hasType(isInteger())))))))),
  143. unless(hasBody(hasSuspiciousStmt("initVarName")))).bind("forLoop");
  144. }
  145. static bool isPossiblyEscaped(const VarDecl *VD, ExplodedNode *N) {
  146. // Global variables assumed as escaped variables.
  147. if (VD->hasGlobalStorage())
  148. return true;
  149. while (!N->pred_empty()) {
  150. // FIXME: getStmtForDiagnostics() does nasty things in order to provide
  151. // a valid statement for body farms, do we need this behavior here?
  152. const Stmt *S = N->getStmtForDiagnostics();
  153. if (!S) {
  154. N = N->getFirstPred();
  155. continue;
  156. }
  157. if (const DeclStmt *DS = dyn_cast<DeclStmt>(S)) {
  158. for (const Decl *D : DS->decls()) {
  159. // Once we reach the declaration of the VD we can return.
  160. if (D->getCanonicalDecl() == VD)
  161. return false;
  162. }
  163. }
  164. // Check the usage of the pass-by-ref function calls and adress-of operator
  165. // on VD and reference initialized by VD.
  166. ASTContext &ASTCtx =
  167. N->getLocationContext()->getAnalysisDeclContext()->getASTContext();
  168. auto Match =
  169. match(stmt(anyOf(callByRef(equalsNode(VD)), getAddrTo(equalsNode(VD)),
  170. assignedToRef(equalsNode(VD)))),
  171. *S, ASTCtx);
  172. if (!Match.empty())
  173. return true;
  174. N = N->getFirstPred();
  175. }
  176. llvm_unreachable("Reached root without finding the declaration of VD");
  177. }
  178. bool shouldCompletelyUnroll(const Stmt *LoopStmt, ASTContext &ASTCtx,
  179. ExplodedNode *Pred, unsigned &maxStep) {
  180. if (!isLoopStmt(LoopStmt))
  181. return false;
  182. // TODO: Match the cases where the bound is not a concrete literal but an
  183. // integer with known value
  184. auto Matches = match(forLoopMatcher(), *LoopStmt, ASTCtx);
  185. if (Matches.empty())
  186. return false;
  187. auto CounterVar = Matches[0].getNodeAs<VarDecl>("initVarName");
  188. llvm::APInt BoundNum =
  189. Matches[0].getNodeAs<IntegerLiteral>("boundNum")->getValue();
  190. llvm::APInt InitNum =
  191. Matches[0].getNodeAs<IntegerLiteral>("initNum")->getValue();
  192. auto CondOp = Matches[0].getNodeAs<BinaryOperator>("conditionOperator");
  193. if (InitNum.getBitWidth() != BoundNum.getBitWidth()) {
  194. InitNum = InitNum.zextOrSelf(BoundNum.getBitWidth());
  195. BoundNum = BoundNum.zextOrSelf(InitNum.getBitWidth());
  196. }
  197. if (CondOp->getOpcode() == BO_GE || CondOp->getOpcode() == BO_LE)
  198. maxStep = (BoundNum - InitNum + 1).abs().getZExtValue();
  199. else
  200. maxStep = (BoundNum - InitNum).abs().getZExtValue();
  201. // Check if the counter of the loop is not escaped before.
  202. return !isPossiblyEscaped(CounterVar->getCanonicalDecl(), Pred);
  203. }
  204. bool madeNewBranch(ExplodedNode *N, const Stmt *LoopStmt) {
  205. const Stmt *S = nullptr;
  206. while (!N->pred_empty()) {
  207. if (N->succ_size() > 1)
  208. return true;
  209. ProgramPoint P = N->getLocation();
  210. if (Optional<BlockEntrance> BE = P.getAs<BlockEntrance>())
  211. S = BE->getBlock()->getTerminatorStmt();
  212. if (S == LoopStmt)
  213. return false;
  214. N = N->getFirstPred();
  215. }
  216. llvm_unreachable("Reached root without encountering the previous step");
  217. }
  218. // updateLoopStack is called on every basic block, therefore it needs to be fast
  219. ProgramStateRef updateLoopStack(const Stmt *LoopStmt, ASTContext &ASTCtx,
  220. ExplodedNode *Pred, unsigned maxVisitOnPath) {
  221. auto State = Pred->getState();
  222. auto LCtx = Pred->getLocationContext();
  223. if (!isLoopStmt(LoopStmt))
  224. return State;
  225. auto LS = State->get<LoopStack>();
  226. if (!LS.isEmpty() && LoopStmt == LS.getHead().getLoopStmt() &&
  227. LCtx == LS.getHead().getLocationContext()) {
  228. if (LS.getHead().isUnrolled() && madeNewBranch(Pred, LoopStmt)) {
  229. State = State->set<LoopStack>(LS.getTail());
  230. State = State->add<LoopStack>(
  231. LoopState::getNormal(LoopStmt, LCtx, maxVisitOnPath));
  232. }
  233. return State;
  234. }
  235. unsigned maxStep;
  236. if (!shouldCompletelyUnroll(LoopStmt, ASTCtx, Pred, maxStep)) {
  237. State = State->add<LoopStack>(
  238. LoopState::getNormal(LoopStmt, LCtx, maxVisitOnPath));
  239. return State;
  240. }
  241. unsigned outerStep = (LS.isEmpty() ? 1 : LS.getHead().getMaxStep());
  242. unsigned innerMaxStep = maxStep * outerStep;
  243. if (innerMaxStep > MAXIMUM_STEP_UNROLLED)
  244. State = State->add<LoopStack>(
  245. LoopState::getNormal(LoopStmt, LCtx, maxVisitOnPath));
  246. else
  247. State = State->add<LoopStack>(
  248. LoopState::getUnrolled(LoopStmt, LCtx, innerMaxStep));
  249. return State;
  250. }
  251. bool isUnrolledState(ProgramStateRef State) {
  252. auto LS = State->get<LoopStack>();
  253. if (LS.isEmpty() || !LS.getHead().isUnrolled())
  254. return false;
  255. return true;
  256. }
  257. }
  258. }