CallGraph.cpp 6.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240
  1. //===- CallGraph.cpp - AST-based Call graph -------------------------------===//
  2. //
  3. // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
  4. // See https://llvm.org/LICENSE.txt for license information.
  5. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
  6. //
  7. //===----------------------------------------------------------------------===//
  8. //
  9. // This file defines the AST-based CallGraph.
  10. //
  11. //===----------------------------------------------------------------------===//
  12. #include "clang/Analysis/CallGraph.h"
  13. #include "clang/AST/Decl.h"
  14. #include "clang/AST/DeclBase.h"
  15. #include "clang/AST/DeclObjC.h"
  16. #include "clang/AST/Expr.h"
  17. #include "clang/AST/ExprObjC.h"
  18. #include "clang/AST/Stmt.h"
  19. #include "clang/AST/StmtVisitor.h"
  20. #include "clang/Basic/IdentifierTable.h"
  21. #include "clang/Basic/LLVM.h"
  22. #include "llvm/ADT/PostOrderIterator.h"
  23. #include "llvm/ADT/STLExtras.h"
  24. #include "llvm/ADT/Statistic.h"
  25. #include "llvm/Support/Casting.h"
  26. #include "llvm/Support/Compiler.h"
  27. #include "llvm/Support/DOTGraphTraits.h"
  28. #include "llvm/Support/GraphWriter.h"
  29. #include "llvm/Support/raw_ostream.h"
  30. #include <cassert>
  31. #include <memory>
  32. #include <string>
  33. using namespace clang;
  34. #define DEBUG_TYPE "CallGraph"
  35. STATISTIC(NumObjCCallEdges, "Number of Objective-C method call edges");
  36. STATISTIC(NumBlockCallEdges, "Number of block call edges");
  37. namespace {
  38. /// A helper class, which walks the AST and locates all the call sites in the
  39. /// given function body.
  40. class CGBuilder : public StmtVisitor<CGBuilder> {
  41. CallGraph *G;
  42. CallGraphNode *CallerNode;
  43. public:
  44. CGBuilder(CallGraph *g, CallGraphNode *N) : G(g), CallerNode(N) {}
  45. void VisitStmt(Stmt *S) { VisitChildren(S); }
  46. Decl *getDeclFromCall(CallExpr *CE) {
  47. if (FunctionDecl *CalleeDecl = CE->getDirectCallee())
  48. return CalleeDecl;
  49. // Simple detection of a call through a block.
  50. Expr *CEE = CE->getCallee()->IgnoreParenImpCasts();
  51. if (BlockExpr *Block = dyn_cast<BlockExpr>(CEE)) {
  52. NumBlockCallEdges++;
  53. return Block->getBlockDecl();
  54. }
  55. return nullptr;
  56. }
  57. void addCalledDecl(Decl *D) {
  58. if (G->includeInGraph(D)) {
  59. CallGraphNode *CalleeNode = G->getOrInsertNode(D);
  60. CallerNode->addCallee(CalleeNode);
  61. }
  62. }
  63. void VisitCallExpr(CallExpr *CE) {
  64. if (Decl *D = getDeclFromCall(CE))
  65. addCalledDecl(D);
  66. VisitChildren(CE);
  67. }
  68. // Adds may-call edges for the ObjC message sends.
  69. void VisitObjCMessageExpr(ObjCMessageExpr *ME) {
  70. if (ObjCInterfaceDecl *IDecl = ME->getReceiverInterface()) {
  71. Selector Sel = ME->getSelector();
  72. // Find the callee definition within the same translation unit.
  73. Decl *D = nullptr;
  74. if (ME->isInstanceMessage())
  75. D = IDecl->lookupPrivateMethod(Sel);
  76. else
  77. D = IDecl->lookupPrivateClassMethod(Sel);
  78. if (D) {
  79. addCalledDecl(D);
  80. NumObjCCallEdges++;
  81. }
  82. }
  83. }
  84. void VisitChildren(Stmt *S) {
  85. for (Stmt *SubStmt : S->children())
  86. if (SubStmt)
  87. this->Visit(SubStmt);
  88. }
  89. };
  90. } // namespace
  91. void CallGraph::addNodesForBlocks(DeclContext *D) {
  92. if (BlockDecl *BD = dyn_cast<BlockDecl>(D))
  93. addNodeForDecl(BD, true);
  94. for (auto *I : D->decls())
  95. if (auto *DC = dyn_cast<DeclContext>(I))
  96. addNodesForBlocks(DC);
  97. }
  98. CallGraph::CallGraph() {
  99. Root = getOrInsertNode(nullptr);
  100. }
  101. CallGraph::~CallGraph() = default;
  102. bool CallGraph::includeInGraph(const Decl *D) {
  103. assert(D);
  104. if (!D->hasBody())
  105. return false;
  106. if (const FunctionDecl *FD = dyn_cast<FunctionDecl>(D)) {
  107. // We skip function template definitions, as their semantics is
  108. // only determined when they are instantiated.
  109. if (FD->isDependentContext())
  110. return false;
  111. IdentifierInfo *II = FD->getIdentifier();
  112. if (II && II->getName().startswith("__inline"))
  113. return false;
  114. }
  115. return true;
  116. }
  117. void CallGraph::addNodeForDecl(Decl* D, bool IsGlobal) {
  118. assert(D);
  119. // Allocate a new node, mark it as root, and process it's calls.
  120. CallGraphNode *Node = getOrInsertNode(D);
  121. // Process all the calls by this function as well.
  122. CGBuilder builder(this, Node);
  123. if (Stmt *Body = D->getBody())
  124. builder.Visit(Body);
  125. }
  126. CallGraphNode *CallGraph::getNode(const Decl *F) const {
  127. FunctionMapTy::const_iterator I = FunctionMap.find(F);
  128. if (I == FunctionMap.end()) return nullptr;
  129. return I->second.get();
  130. }
  131. CallGraphNode *CallGraph::getOrInsertNode(Decl *F) {
  132. if (F && !isa<ObjCMethodDecl>(F))
  133. F = F->getCanonicalDecl();
  134. std::unique_ptr<CallGraphNode> &Node = FunctionMap[F];
  135. if (Node)
  136. return Node.get();
  137. Node = std::make_unique<CallGraphNode>(F);
  138. // Make Root node a parent of all functions to make sure all are reachable.
  139. if (F)
  140. Root->addCallee(Node.get());
  141. return Node.get();
  142. }
  143. void CallGraph::print(raw_ostream &OS) const {
  144. OS << " --- Call graph Dump --- \n";
  145. // We are going to print the graph in reverse post order, partially, to make
  146. // sure the output is deterministic.
  147. llvm::ReversePostOrderTraversal<const CallGraph *> RPOT(this);
  148. for (llvm::ReversePostOrderTraversal<const CallGraph *>::rpo_iterator
  149. I = RPOT.begin(), E = RPOT.end(); I != E; ++I) {
  150. const CallGraphNode *N = *I;
  151. OS << " Function: ";
  152. if (N == Root)
  153. OS << "< root >";
  154. else
  155. N->print(OS);
  156. OS << " calls: ";
  157. for (CallGraphNode::const_iterator CI = N->begin(),
  158. CE = N->end(); CI != CE; ++CI) {
  159. assert(*CI != Root && "No one can call the root node.");
  160. (*CI)->print(OS);
  161. OS << " ";
  162. }
  163. OS << '\n';
  164. }
  165. OS.flush();
  166. }
  167. LLVM_DUMP_METHOD void CallGraph::dump() const {
  168. print(llvm::errs());
  169. }
  170. void CallGraph::viewGraph() const {
  171. llvm::ViewGraph(this, "CallGraph");
  172. }
  173. void CallGraphNode::print(raw_ostream &os) const {
  174. if (const NamedDecl *ND = dyn_cast_or_null<NamedDecl>(FD))
  175. return ND->printQualifiedName(os);
  176. os << "< >";
  177. }
  178. LLVM_DUMP_METHOD void CallGraphNode::dump() const {
  179. print(llvm::errs());
  180. }
  181. namespace llvm {
  182. template <>
  183. struct DOTGraphTraits<const CallGraph*> : public DefaultDOTGraphTraits {
  184. DOTGraphTraits (bool isSimple = false) : DefaultDOTGraphTraits(isSimple) {}
  185. static std::string getNodeLabel(const CallGraphNode *Node,
  186. const CallGraph *CG) {
  187. if (CG->getRoot() == Node) {
  188. return "< root >";
  189. }
  190. if (const NamedDecl *ND = dyn_cast_or_null<NamedDecl>(Node->getDecl()))
  191. return ND->getNameAsString();
  192. else
  193. return "< >";
  194. }
  195. };
  196. } // namespace llvm