CodeGenPGO.cpp 31 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898
  1. //===--- CodeGenPGO.cpp - PGO Instrumentation for LLVM CodeGen --*- C++ -*-===//
  2. //
  3. // The LLVM Compiler Infrastructure
  4. //
  5. // This file is distributed under the University of Illinois Open Source
  6. // License. See LICENSE.TXT for details.
  7. //
  8. //===----------------------------------------------------------------------===//
  9. //
  10. // Instrumentation-based profile-guided optimization
  11. //
  12. //===----------------------------------------------------------------------===//
  13. #include "CodeGenPGO.h"
  14. #include "CodeGenFunction.h"
  15. #include "CoverageMappingGen.h"
  16. #include "clang/AST/RecursiveASTVisitor.h"
  17. #include "clang/AST/StmtVisitor.h"
  18. #include "llvm/IR/Intrinsics.h"
  19. #include "llvm/IR/MDBuilder.h"
  20. #include "llvm/Support/Endian.h"
  21. #include "llvm/Support/FileSystem.h"
  22. #include "llvm/Support/MD5.h"
  23. static llvm::cl::opt<bool> EnableValueProfiling(
  24. "enable-value-profiling", llvm::cl::ZeroOrMore,
  25. llvm::cl::desc("Enable value profiling"), llvm::cl::init(false));
  26. using namespace clang;
  27. using namespace CodeGen;
  28. void CodeGenPGO::setFuncName(StringRef Name,
  29. llvm::GlobalValue::LinkageTypes Linkage) {
  30. llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
  31. FuncName = llvm::getPGOFuncName(
  32. Name, Linkage, CGM.getCodeGenOpts().MainFileName,
  33. PGOReader ? PGOReader->getVersion() : llvm::IndexedInstrProf::Version);
  34. // If we're generating a profile, create a variable for the name.
  35. if (CGM.getCodeGenOpts().hasProfileClangInstr())
  36. FuncNameVar = llvm::createPGOFuncNameVar(CGM.getModule(), Linkage, FuncName);
  37. }
  38. void CodeGenPGO::setFuncName(llvm::Function *Fn) {
  39. setFuncName(Fn->getName(), Fn->getLinkage());
  40. // Create PGOFuncName meta data.
  41. llvm::createPGOFuncNameMetadata(*Fn, FuncName);
  42. }
  43. namespace {
  44. /// \brief Stable hasher for PGO region counters.
  45. ///
  46. /// PGOHash produces a stable hash of a given function's control flow.
  47. ///
  48. /// Changing the output of this hash will invalidate all previously generated
  49. /// profiles -- i.e., don't do it.
  50. ///
  51. /// \note When this hash does eventually change (years?), we still need to
  52. /// support old hashes. We'll need to pull in the version number from the
  53. /// profile data format and use the matching hash function.
  54. class PGOHash {
  55. uint64_t Working;
  56. unsigned Count;
  57. llvm::MD5 MD5;
  58. static const int NumBitsPerType = 6;
  59. static const unsigned NumTypesPerWord = sizeof(uint64_t) * 8 / NumBitsPerType;
  60. static const unsigned TooBig = 1u << NumBitsPerType;
  61. public:
  62. /// \brief Hash values for AST nodes.
  63. ///
  64. /// Distinct values for AST nodes that have region counters attached.
  65. ///
  66. /// These values must be stable. All new members must be added at the end,
  67. /// and no members should be removed. Changing the enumeration value for an
  68. /// AST node will affect the hash of every function that contains that node.
  69. enum HashType : unsigned char {
  70. None = 0,
  71. LabelStmt = 1,
  72. WhileStmt,
  73. DoStmt,
  74. ForStmt,
  75. CXXForRangeStmt,
  76. ObjCForCollectionStmt,
  77. SwitchStmt,
  78. CaseStmt,
  79. DefaultStmt,
  80. IfStmt,
  81. CXXTryStmt,
  82. CXXCatchStmt,
  83. ConditionalOperator,
  84. BinaryOperatorLAnd,
  85. BinaryOperatorLOr,
  86. BinaryConditionalOperator,
  87. // Keep this last. It's for the static assert that follows.
  88. LastHashType
  89. };
  90. static_assert(LastHashType <= TooBig, "Too many types in HashType");
  91. // TODO: When this format changes, take in a version number here, and use the
  92. // old hash calculation for file formats that used the old hash.
  93. PGOHash() : Working(0), Count(0) {}
  94. void combine(HashType Type);
  95. uint64_t finalize();
  96. };
  97. const int PGOHash::NumBitsPerType;
  98. const unsigned PGOHash::NumTypesPerWord;
  99. const unsigned PGOHash::TooBig;
  100. /// A RecursiveASTVisitor that fills a map of statements to PGO counters.
  101. struct MapRegionCounters : public RecursiveASTVisitor<MapRegionCounters> {
  102. /// The next counter value to assign.
  103. unsigned NextCounter;
  104. /// The function hash.
  105. PGOHash Hash;
  106. /// The map of statements to counters.
  107. llvm::DenseMap<const Stmt *, unsigned> &CounterMap;
  108. MapRegionCounters(llvm::DenseMap<const Stmt *, unsigned> &CounterMap)
  109. : NextCounter(0), CounterMap(CounterMap) {}
  110. // Blocks and lambdas are handled as separate functions, so we need not
  111. // traverse them in the parent context.
  112. bool TraverseBlockExpr(BlockExpr *BE) { return true; }
  113. bool TraverseLambdaBody(LambdaExpr *LE) { return true; }
  114. bool TraverseCapturedStmt(CapturedStmt *CS) { return true; }
  115. bool VisitDecl(const Decl *D) {
  116. switch (D->getKind()) {
  117. default:
  118. break;
  119. case Decl::Function:
  120. case Decl::CXXMethod:
  121. case Decl::CXXConstructor:
  122. case Decl::CXXDestructor:
  123. case Decl::CXXConversion:
  124. case Decl::ObjCMethod:
  125. case Decl::Block:
  126. case Decl::Captured:
  127. CounterMap[D->getBody()] = NextCounter++;
  128. break;
  129. }
  130. return true;
  131. }
  132. bool VisitStmt(const Stmt *S) {
  133. auto Type = getHashType(S);
  134. if (Type == PGOHash::None)
  135. return true;
  136. CounterMap[S] = NextCounter++;
  137. Hash.combine(Type);
  138. return true;
  139. }
  140. PGOHash::HashType getHashType(const Stmt *S) {
  141. switch (S->getStmtClass()) {
  142. default:
  143. break;
  144. case Stmt::LabelStmtClass:
  145. return PGOHash::LabelStmt;
  146. case Stmt::WhileStmtClass:
  147. return PGOHash::WhileStmt;
  148. case Stmt::DoStmtClass:
  149. return PGOHash::DoStmt;
  150. case Stmt::ForStmtClass:
  151. return PGOHash::ForStmt;
  152. case Stmt::CXXForRangeStmtClass:
  153. return PGOHash::CXXForRangeStmt;
  154. case Stmt::ObjCForCollectionStmtClass:
  155. return PGOHash::ObjCForCollectionStmt;
  156. case Stmt::SwitchStmtClass:
  157. return PGOHash::SwitchStmt;
  158. case Stmt::CaseStmtClass:
  159. return PGOHash::CaseStmt;
  160. case Stmt::DefaultStmtClass:
  161. return PGOHash::DefaultStmt;
  162. case Stmt::IfStmtClass:
  163. return PGOHash::IfStmt;
  164. case Stmt::CXXTryStmtClass:
  165. return PGOHash::CXXTryStmt;
  166. case Stmt::CXXCatchStmtClass:
  167. return PGOHash::CXXCatchStmt;
  168. case Stmt::ConditionalOperatorClass:
  169. return PGOHash::ConditionalOperator;
  170. case Stmt::BinaryConditionalOperatorClass:
  171. return PGOHash::BinaryConditionalOperator;
  172. case Stmt::BinaryOperatorClass: {
  173. const BinaryOperator *BO = cast<BinaryOperator>(S);
  174. if (BO->getOpcode() == BO_LAnd)
  175. return PGOHash::BinaryOperatorLAnd;
  176. if (BO->getOpcode() == BO_LOr)
  177. return PGOHash::BinaryOperatorLOr;
  178. break;
  179. }
  180. }
  181. return PGOHash::None;
  182. }
  183. };
  184. /// A StmtVisitor that propagates the raw counts through the AST and
  185. /// records the count at statements where the value may change.
  186. struct ComputeRegionCounts : public ConstStmtVisitor<ComputeRegionCounts> {
  187. /// PGO state.
  188. CodeGenPGO &PGO;
  189. /// A flag that is set when the current count should be recorded on the
  190. /// next statement, such as at the exit of a loop.
  191. bool RecordNextStmtCount;
  192. /// The count at the current location in the traversal.
  193. uint64_t CurrentCount;
  194. /// The map of statements to count values.
  195. llvm::DenseMap<const Stmt *, uint64_t> &CountMap;
  196. /// BreakContinueStack - Keep counts of breaks and continues inside loops.
  197. struct BreakContinue {
  198. uint64_t BreakCount;
  199. uint64_t ContinueCount;
  200. BreakContinue() : BreakCount(0), ContinueCount(0) {}
  201. };
  202. SmallVector<BreakContinue, 8> BreakContinueStack;
  203. ComputeRegionCounts(llvm::DenseMap<const Stmt *, uint64_t> &CountMap,
  204. CodeGenPGO &PGO)
  205. : PGO(PGO), RecordNextStmtCount(false), CountMap(CountMap) {}
  206. void RecordStmtCount(const Stmt *S) {
  207. if (RecordNextStmtCount) {
  208. CountMap[S] = CurrentCount;
  209. RecordNextStmtCount = false;
  210. }
  211. }
  212. /// Set and return the current count.
  213. uint64_t setCount(uint64_t Count) {
  214. CurrentCount = Count;
  215. return Count;
  216. }
  217. void VisitStmt(const Stmt *S) {
  218. RecordStmtCount(S);
  219. for (const Stmt *Child : S->children())
  220. if (Child)
  221. this->Visit(Child);
  222. }
  223. void VisitFunctionDecl(const FunctionDecl *D) {
  224. // Counter tracks entry to the function body.
  225. uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
  226. CountMap[D->getBody()] = BodyCount;
  227. Visit(D->getBody());
  228. }
  229. // Skip lambda expressions. We visit these as FunctionDecls when we're
  230. // generating them and aren't interested in the body when generating a
  231. // parent context.
  232. void VisitLambdaExpr(const LambdaExpr *LE) {}
  233. void VisitCapturedDecl(const CapturedDecl *D) {
  234. // Counter tracks entry to the capture body.
  235. uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
  236. CountMap[D->getBody()] = BodyCount;
  237. Visit(D->getBody());
  238. }
  239. void VisitObjCMethodDecl(const ObjCMethodDecl *D) {
  240. // Counter tracks entry to the method body.
  241. uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
  242. CountMap[D->getBody()] = BodyCount;
  243. Visit(D->getBody());
  244. }
  245. void VisitBlockDecl(const BlockDecl *D) {
  246. // Counter tracks entry to the block body.
  247. uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
  248. CountMap[D->getBody()] = BodyCount;
  249. Visit(D->getBody());
  250. }
  251. void VisitReturnStmt(const ReturnStmt *S) {
  252. RecordStmtCount(S);
  253. if (S->getRetValue())
  254. Visit(S->getRetValue());
  255. CurrentCount = 0;
  256. RecordNextStmtCount = true;
  257. }
  258. void VisitCXXThrowExpr(const CXXThrowExpr *E) {
  259. RecordStmtCount(E);
  260. if (E->getSubExpr())
  261. Visit(E->getSubExpr());
  262. CurrentCount = 0;
  263. RecordNextStmtCount = true;
  264. }
  265. void VisitGotoStmt(const GotoStmt *S) {
  266. RecordStmtCount(S);
  267. CurrentCount = 0;
  268. RecordNextStmtCount = true;
  269. }
  270. void VisitLabelStmt(const LabelStmt *S) {
  271. RecordNextStmtCount = false;
  272. // Counter tracks the block following the label.
  273. uint64_t BlockCount = setCount(PGO.getRegionCount(S));
  274. CountMap[S] = BlockCount;
  275. Visit(S->getSubStmt());
  276. }
  277. void VisitBreakStmt(const BreakStmt *S) {
  278. RecordStmtCount(S);
  279. assert(!BreakContinueStack.empty() && "break not in a loop or switch!");
  280. BreakContinueStack.back().BreakCount += CurrentCount;
  281. CurrentCount = 0;
  282. RecordNextStmtCount = true;
  283. }
  284. void VisitContinueStmt(const ContinueStmt *S) {
  285. RecordStmtCount(S);
  286. assert(!BreakContinueStack.empty() && "continue stmt not in a loop!");
  287. BreakContinueStack.back().ContinueCount += CurrentCount;
  288. CurrentCount = 0;
  289. RecordNextStmtCount = true;
  290. }
  291. void VisitWhileStmt(const WhileStmt *S) {
  292. RecordStmtCount(S);
  293. uint64_t ParentCount = CurrentCount;
  294. BreakContinueStack.push_back(BreakContinue());
  295. // Visit the body region first so the break/continue adjustments can be
  296. // included when visiting the condition.
  297. uint64_t BodyCount = setCount(PGO.getRegionCount(S));
  298. CountMap[S->getBody()] = CurrentCount;
  299. Visit(S->getBody());
  300. uint64_t BackedgeCount = CurrentCount;
  301. // ...then go back and propagate counts through the condition. The count
  302. // at the start of the condition is the sum of the incoming edges,
  303. // the backedge from the end of the loop body, and the edges from
  304. // continue statements.
  305. BreakContinue BC = BreakContinueStack.pop_back_val();
  306. uint64_t CondCount =
  307. setCount(ParentCount + BackedgeCount + BC.ContinueCount);
  308. CountMap[S->getCond()] = CondCount;
  309. Visit(S->getCond());
  310. setCount(BC.BreakCount + CondCount - BodyCount);
  311. RecordNextStmtCount = true;
  312. }
  313. void VisitDoStmt(const DoStmt *S) {
  314. RecordStmtCount(S);
  315. uint64_t LoopCount = PGO.getRegionCount(S);
  316. BreakContinueStack.push_back(BreakContinue());
  317. // The count doesn't include the fallthrough from the parent scope. Add it.
  318. uint64_t BodyCount = setCount(LoopCount + CurrentCount);
  319. CountMap[S->getBody()] = BodyCount;
  320. Visit(S->getBody());
  321. uint64_t BackedgeCount = CurrentCount;
  322. BreakContinue BC = BreakContinueStack.pop_back_val();
  323. // The count at the start of the condition is equal to the count at the
  324. // end of the body, plus any continues.
  325. uint64_t CondCount = setCount(BackedgeCount + BC.ContinueCount);
  326. CountMap[S->getCond()] = CondCount;
  327. Visit(S->getCond());
  328. setCount(BC.BreakCount + CondCount - LoopCount);
  329. RecordNextStmtCount = true;
  330. }
  331. void VisitForStmt(const ForStmt *S) {
  332. RecordStmtCount(S);
  333. if (S->getInit())
  334. Visit(S->getInit());
  335. uint64_t ParentCount = CurrentCount;
  336. BreakContinueStack.push_back(BreakContinue());
  337. // Visit the body region first. (This is basically the same as a while
  338. // loop; see further comments in VisitWhileStmt.)
  339. uint64_t BodyCount = setCount(PGO.getRegionCount(S));
  340. CountMap[S->getBody()] = BodyCount;
  341. Visit(S->getBody());
  342. uint64_t BackedgeCount = CurrentCount;
  343. BreakContinue BC = BreakContinueStack.pop_back_val();
  344. // The increment is essentially part of the body but it needs to include
  345. // the count for all the continue statements.
  346. if (S->getInc()) {
  347. uint64_t IncCount = setCount(BackedgeCount + BC.ContinueCount);
  348. CountMap[S->getInc()] = IncCount;
  349. Visit(S->getInc());
  350. }
  351. // ...then go back and propagate counts through the condition.
  352. uint64_t CondCount =
  353. setCount(ParentCount + BackedgeCount + BC.ContinueCount);
  354. if (S->getCond()) {
  355. CountMap[S->getCond()] = CondCount;
  356. Visit(S->getCond());
  357. }
  358. setCount(BC.BreakCount + CondCount - BodyCount);
  359. RecordNextStmtCount = true;
  360. }
  361. void VisitCXXForRangeStmt(const CXXForRangeStmt *S) {
  362. RecordStmtCount(S);
  363. Visit(S->getLoopVarStmt());
  364. Visit(S->getRangeStmt());
  365. Visit(S->getBeginStmt());
  366. Visit(S->getEndStmt());
  367. uint64_t ParentCount = CurrentCount;
  368. BreakContinueStack.push_back(BreakContinue());
  369. // Visit the body region first. (This is basically the same as a while
  370. // loop; see further comments in VisitWhileStmt.)
  371. uint64_t BodyCount = setCount(PGO.getRegionCount(S));
  372. CountMap[S->getBody()] = BodyCount;
  373. Visit(S->getBody());
  374. uint64_t BackedgeCount = CurrentCount;
  375. BreakContinue BC = BreakContinueStack.pop_back_val();
  376. // The increment is essentially part of the body but it needs to include
  377. // the count for all the continue statements.
  378. uint64_t IncCount = setCount(BackedgeCount + BC.ContinueCount);
  379. CountMap[S->getInc()] = IncCount;
  380. Visit(S->getInc());
  381. // ...then go back and propagate counts through the condition.
  382. uint64_t CondCount =
  383. setCount(ParentCount + BackedgeCount + BC.ContinueCount);
  384. CountMap[S->getCond()] = CondCount;
  385. Visit(S->getCond());
  386. setCount(BC.BreakCount + CondCount - BodyCount);
  387. RecordNextStmtCount = true;
  388. }
  389. void VisitObjCForCollectionStmt(const ObjCForCollectionStmt *S) {
  390. RecordStmtCount(S);
  391. Visit(S->getElement());
  392. uint64_t ParentCount = CurrentCount;
  393. BreakContinueStack.push_back(BreakContinue());
  394. // Counter tracks the body of the loop.
  395. uint64_t BodyCount = setCount(PGO.getRegionCount(S));
  396. CountMap[S->getBody()] = BodyCount;
  397. Visit(S->getBody());
  398. uint64_t BackedgeCount = CurrentCount;
  399. BreakContinue BC = BreakContinueStack.pop_back_val();
  400. setCount(BC.BreakCount + ParentCount + BackedgeCount + BC.ContinueCount -
  401. BodyCount);
  402. RecordNextStmtCount = true;
  403. }
  404. void VisitSwitchStmt(const SwitchStmt *S) {
  405. RecordStmtCount(S);
  406. if (S->getInit())
  407. Visit(S->getInit());
  408. Visit(S->getCond());
  409. CurrentCount = 0;
  410. BreakContinueStack.push_back(BreakContinue());
  411. Visit(S->getBody());
  412. // If the switch is inside a loop, add the continue counts.
  413. BreakContinue BC = BreakContinueStack.pop_back_val();
  414. if (!BreakContinueStack.empty())
  415. BreakContinueStack.back().ContinueCount += BC.ContinueCount;
  416. // Counter tracks the exit block of the switch.
  417. setCount(PGO.getRegionCount(S));
  418. RecordNextStmtCount = true;
  419. }
  420. void VisitSwitchCase(const SwitchCase *S) {
  421. RecordNextStmtCount = false;
  422. // Counter for this particular case. This counts only jumps from the
  423. // switch header and does not include fallthrough from the case before
  424. // this one.
  425. uint64_t CaseCount = PGO.getRegionCount(S);
  426. setCount(CurrentCount + CaseCount);
  427. // We need the count without fallthrough in the mapping, so it's more useful
  428. // for branch probabilities.
  429. CountMap[S] = CaseCount;
  430. RecordNextStmtCount = true;
  431. Visit(S->getSubStmt());
  432. }
  433. void VisitIfStmt(const IfStmt *S) {
  434. RecordStmtCount(S);
  435. uint64_t ParentCount = CurrentCount;
  436. if (S->getInit())
  437. Visit(S->getInit());
  438. Visit(S->getCond());
  439. // Counter tracks the "then" part of an if statement. The count for
  440. // the "else" part, if it exists, will be calculated from this counter.
  441. uint64_t ThenCount = setCount(PGO.getRegionCount(S));
  442. CountMap[S->getThen()] = ThenCount;
  443. Visit(S->getThen());
  444. uint64_t OutCount = CurrentCount;
  445. uint64_t ElseCount = ParentCount - ThenCount;
  446. if (S->getElse()) {
  447. setCount(ElseCount);
  448. CountMap[S->getElse()] = ElseCount;
  449. Visit(S->getElse());
  450. OutCount += CurrentCount;
  451. } else
  452. OutCount += ElseCount;
  453. setCount(OutCount);
  454. RecordNextStmtCount = true;
  455. }
  456. void VisitCXXTryStmt(const CXXTryStmt *S) {
  457. RecordStmtCount(S);
  458. Visit(S->getTryBlock());
  459. for (unsigned I = 0, E = S->getNumHandlers(); I < E; ++I)
  460. Visit(S->getHandler(I));
  461. // Counter tracks the continuation block of the try statement.
  462. setCount(PGO.getRegionCount(S));
  463. RecordNextStmtCount = true;
  464. }
  465. void VisitCXXCatchStmt(const CXXCatchStmt *S) {
  466. RecordNextStmtCount = false;
  467. // Counter tracks the catch statement's handler block.
  468. uint64_t CatchCount = setCount(PGO.getRegionCount(S));
  469. CountMap[S] = CatchCount;
  470. Visit(S->getHandlerBlock());
  471. }
  472. void VisitAbstractConditionalOperator(const AbstractConditionalOperator *E) {
  473. RecordStmtCount(E);
  474. uint64_t ParentCount = CurrentCount;
  475. Visit(E->getCond());
  476. // Counter tracks the "true" part of a conditional operator. The
  477. // count in the "false" part will be calculated from this counter.
  478. uint64_t TrueCount = setCount(PGO.getRegionCount(E));
  479. CountMap[E->getTrueExpr()] = TrueCount;
  480. Visit(E->getTrueExpr());
  481. uint64_t OutCount = CurrentCount;
  482. uint64_t FalseCount = setCount(ParentCount - TrueCount);
  483. CountMap[E->getFalseExpr()] = FalseCount;
  484. Visit(E->getFalseExpr());
  485. OutCount += CurrentCount;
  486. setCount(OutCount);
  487. RecordNextStmtCount = true;
  488. }
  489. void VisitBinLAnd(const BinaryOperator *E) {
  490. RecordStmtCount(E);
  491. uint64_t ParentCount = CurrentCount;
  492. Visit(E->getLHS());
  493. // Counter tracks the right hand side of a logical and operator.
  494. uint64_t RHSCount = setCount(PGO.getRegionCount(E));
  495. CountMap[E->getRHS()] = RHSCount;
  496. Visit(E->getRHS());
  497. setCount(ParentCount + RHSCount - CurrentCount);
  498. RecordNextStmtCount = true;
  499. }
  500. void VisitBinLOr(const BinaryOperator *E) {
  501. RecordStmtCount(E);
  502. uint64_t ParentCount = CurrentCount;
  503. Visit(E->getLHS());
  504. // Counter tracks the right hand side of a logical or operator.
  505. uint64_t RHSCount = setCount(PGO.getRegionCount(E));
  506. CountMap[E->getRHS()] = RHSCount;
  507. Visit(E->getRHS());
  508. setCount(ParentCount + RHSCount - CurrentCount);
  509. RecordNextStmtCount = true;
  510. }
  511. };
  512. } // end anonymous namespace
  513. void PGOHash::combine(HashType Type) {
  514. // Check that we never combine 0 and only have six bits.
  515. assert(Type && "Hash is invalid: unexpected type 0");
  516. assert(unsigned(Type) < TooBig && "Hash is invalid: too many types");
  517. // Pass through MD5 if enough work has built up.
  518. if (Count && Count % NumTypesPerWord == 0) {
  519. using namespace llvm::support;
  520. uint64_t Swapped = endian::byte_swap<uint64_t, little>(Working);
  521. MD5.update(llvm::makeArrayRef((uint8_t *)&Swapped, sizeof(Swapped)));
  522. Working = 0;
  523. }
  524. // Accumulate the current type.
  525. ++Count;
  526. Working = Working << NumBitsPerType | Type;
  527. }
  528. uint64_t PGOHash::finalize() {
  529. // Use Working as the hash directly if we never used MD5.
  530. if (Count <= NumTypesPerWord)
  531. // No need to byte swap here, since none of the math was endian-dependent.
  532. // This number will be byte-swapped as required on endianness transitions,
  533. // so we will see the same value on the other side.
  534. return Working;
  535. // Check for remaining work in Working.
  536. if (Working)
  537. MD5.update(Working);
  538. // Finalize the MD5 and return the hash.
  539. llvm::MD5::MD5Result Result;
  540. MD5.final(Result);
  541. using namespace llvm::support;
  542. return endian::read<uint64_t, little, unaligned>(Result);
  543. }
  544. void CodeGenPGO::assignRegionCounters(GlobalDecl GD, llvm::Function *Fn) {
  545. const Decl *D = GD.getDecl();
  546. bool InstrumentRegions = CGM.getCodeGenOpts().hasProfileClangInstr();
  547. llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
  548. if (!InstrumentRegions && !PGOReader)
  549. return;
  550. if (D->isImplicit())
  551. return;
  552. // Constructors and destructors may be represented by several functions in IR.
  553. // If so, instrument only base variant, others are implemented by delegation
  554. // to the base one, it would be counted twice otherwise.
  555. if (CGM.getTarget().getCXXABI().hasConstructorVariants() &&
  556. ((isa<CXXConstructorDecl>(GD.getDecl()) &&
  557. GD.getCtorType() != Ctor_Base) ||
  558. (isa<CXXDestructorDecl>(GD.getDecl()) &&
  559. GD.getDtorType() != Dtor_Base))) {
  560. return;
  561. }
  562. CGM.ClearUnusedCoverageMapping(D);
  563. setFuncName(Fn);
  564. mapRegionCounters(D);
  565. if (CGM.getCodeGenOpts().CoverageMapping)
  566. emitCounterRegionMapping(D);
  567. if (PGOReader) {
  568. SourceManager &SM = CGM.getContext().getSourceManager();
  569. loadRegionCounts(PGOReader, SM.isInMainFile(D->getLocation()));
  570. computeRegionCounts(D);
  571. applyFunctionAttributes(PGOReader, Fn);
  572. }
  573. }
  574. void CodeGenPGO::mapRegionCounters(const Decl *D) {
  575. RegionCounterMap.reset(new llvm::DenseMap<const Stmt *, unsigned>);
  576. MapRegionCounters Walker(*RegionCounterMap);
  577. if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D))
  578. Walker.TraverseDecl(const_cast<FunctionDecl *>(FD));
  579. else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D))
  580. Walker.TraverseDecl(const_cast<ObjCMethodDecl *>(MD));
  581. else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D))
  582. Walker.TraverseDecl(const_cast<BlockDecl *>(BD));
  583. else if (const CapturedDecl *CD = dyn_cast_or_null<CapturedDecl>(D))
  584. Walker.TraverseDecl(const_cast<CapturedDecl *>(CD));
  585. assert(Walker.NextCounter > 0 && "no entry counter mapped for decl");
  586. NumRegionCounters = Walker.NextCounter;
  587. FunctionHash = Walker.Hash.finalize();
  588. }
  589. bool CodeGenPGO::skipRegionMappingForDecl(const Decl *D) {
  590. if (SkipCoverageMapping)
  591. return true;
  592. // Don't map the functions in system headers.
  593. const auto &SM = CGM.getContext().getSourceManager();
  594. auto Loc = D->getBody()->getLocStart();
  595. return SM.isInSystemHeader(Loc);
  596. }
  597. void CodeGenPGO::emitCounterRegionMapping(const Decl *D) {
  598. if (skipRegionMappingForDecl(D))
  599. return;
  600. std::string CoverageMapping;
  601. llvm::raw_string_ostream OS(CoverageMapping);
  602. CoverageMappingGen MappingGen(*CGM.getCoverageMapping(),
  603. CGM.getContext().getSourceManager(),
  604. CGM.getLangOpts(), RegionCounterMap.get());
  605. MappingGen.emitCounterMapping(D, OS);
  606. OS.flush();
  607. if (CoverageMapping.empty())
  608. return;
  609. CGM.getCoverageMapping()->addFunctionMappingRecord(
  610. FuncNameVar, FuncName, FunctionHash, CoverageMapping);
  611. }
  612. void
  613. CodeGenPGO::emitEmptyCounterMapping(const Decl *D, StringRef Name,
  614. llvm::GlobalValue::LinkageTypes Linkage) {
  615. if (skipRegionMappingForDecl(D))
  616. return;
  617. std::string CoverageMapping;
  618. llvm::raw_string_ostream OS(CoverageMapping);
  619. CoverageMappingGen MappingGen(*CGM.getCoverageMapping(),
  620. CGM.getContext().getSourceManager(),
  621. CGM.getLangOpts());
  622. MappingGen.emitEmptyMapping(D, OS);
  623. OS.flush();
  624. if (CoverageMapping.empty())
  625. return;
  626. setFuncName(Name, Linkage);
  627. CGM.getCoverageMapping()->addFunctionMappingRecord(
  628. FuncNameVar, FuncName, FunctionHash, CoverageMapping, false);
  629. }
  630. void CodeGenPGO::computeRegionCounts(const Decl *D) {
  631. StmtCountMap.reset(new llvm::DenseMap<const Stmt *, uint64_t>);
  632. ComputeRegionCounts Walker(*StmtCountMap, *this);
  633. if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D))
  634. Walker.VisitFunctionDecl(FD);
  635. else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D))
  636. Walker.VisitObjCMethodDecl(MD);
  637. else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D))
  638. Walker.VisitBlockDecl(BD);
  639. else if (const CapturedDecl *CD = dyn_cast_or_null<CapturedDecl>(D))
  640. Walker.VisitCapturedDecl(const_cast<CapturedDecl *>(CD));
  641. }
  642. void
  643. CodeGenPGO::applyFunctionAttributes(llvm::IndexedInstrProfReader *PGOReader,
  644. llvm::Function *Fn) {
  645. if (!haveRegionCounts())
  646. return;
  647. uint64_t FunctionCount = getRegionCount(nullptr);
  648. Fn->setEntryCount(FunctionCount);
  649. }
  650. void CodeGenPGO::emitCounterIncrement(CGBuilderTy &Builder, const Stmt *S) {
  651. if (!CGM.getCodeGenOpts().hasProfileClangInstr() || !RegionCounterMap)
  652. return;
  653. if (!Builder.GetInsertBlock())
  654. return;
  655. unsigned Counter = (*RegionCounterMap)[S];
  656. auto *I8PtrTy = llvm::Type::getInt8PtrTy(CGM.getLLVMContext());
  657. Builder.CreateCall(CGM.getIntrinsic(llvm::Intrinsic::instrprof_increment),
  658. {llvm::ConstantExpr::getBitCast(FuncNameVar, I8PtrTy),
  659. Builder.getInt64(FunctionHash),
  660. Builder.getInt32(NumRegionCounters),
  661. Builder.getInt32(Counter)});
  662. }
  663. // This method either inserts a call to the profile run-time during
  664. // instrumentation or puts profile data into metadata for PGO use.
  665. void CodeGenPGO::valueProfile(CGBuilderTy &Builder, uint32_t ValueKind,
  666. llvm::Instruction *ValueSite, llvm::Value *ValuePtr) {
  667. if (!EnableValueProfiling)
  668. return;
  669. if (!ValuePtr || !ValueSite || !Builder.GetInsertBlock())
  670. return;
  671. if (isa<llvm::Constant>(ValuePtr))
  672. return;
  673. bool InstrumentValueSites = CGM.getCodeGenOpts().hasProfileClangInstr();
  674. if (InstrumentValueSites && RegionCounterMap) {
  675. auto BuilderInsertPoint = Builder.saveIP();
  676. Builder.SetInsertPoint(ValueSite);
  677. llvm::Value *Args[5] = {
  678. llvm::ConstantExpr::getBitCast(FuncNameVar, Builder.getInt8PtrTy()),
  679. Builder.getInt64(FunctionHash),
  680. Builder.CreatePtrToInt(ValuePtr, Builder.getInt64Ty()),
  681. Builder.getInt32(ValueKind),
  682. Builder.getInt32(NumValueSites[ValueKind]++)
  683. };
  684. Builder.CreateCall(
  685. CGM.getIntrinsic(llvm::Intrinsic::instrprof_value_profile), Args);
  686. Builder.restoreIP(BuilderInsertPoint);
  687. return;
  688. }
  689. llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
  690. if (PGOReader && haveRegionCounts()) {
  691. // We record the top most called three functions at each call site.
  692. // Profile metadata contains "VP" string identifying this metadata
  693. // as value profiling data, then a uint32_t value for the value profiling
  694. // kind, a uint64_t value for the total number of times the call is
  695. // executed, followed by the function hash and execution count (uint64_t)
  696. // pairs for each function.
  697. if (NumValueSites[ValueKind] >= ProfRecord->getNumValueSites(ValueKind))
  698. return;
  699. llvm::annotateValueSite(CGM.getModule(), *ValueSite, *ProfRecord,
  700. (llvm::InstrProfValueKind)ValueKind,
  701. NumValueSites[ValueKind]);
  702. NumValueSites[ValueKind]++;
  703. }
  704. }
  705. void CodeGenPGO::loadRegionCounts(llvm::IndexedInstrProfReader *PGOReader,
  706. bool IsInMainFile) {
  707. CGM.getPGOStats().addVisited(IsInMainFile);
  708. RegionCounts.clear();
  709. llvm::Expected<llvm::InstrProfRecord> RecordExpected =
  710. PGOReader->getInstrProfRecord(FuncName, FunctionHash);
  711. if (auto E = RecordExpected.takeError()) {
  712. auto IPE = llvm::InstrProfError::take(std::move(E));
  713. if (IPE == llvm::instrprof_error::unknown_function)
  714. CGM.getPGOStats().addMissing(IsInMainFile);
  715. else if (IPE == llvm::instrprof_error::hash_mismatch)
  716. CGM.getPGOStats().addMismatched(IsInMainFile);
  717. else if (IPE == llvm::instrprof_error::malformed)
  718. // TODO: Consider a more specific warning for this case.
  719. CGM.getPGOStats().addMismatched(IsInMainFile);
  720. return;
  721. }
  722. ProfRecord =
  723. llvm::make_unique<llvm::InstrProfRecord>(std::move(RecordExpected.get()));
  724. RegionCounts = ProfRecord->Counts;
  725. }
  726. /// \brief Calculate what to divide by to scale weights.
  727. ///
  728. /// Given the maximum weight, calculate a divisor that will scale all the
  729. /// weights to strictly less than UINT32_MAX.
  730. static uint64_t calculateWeightScale(uint64_t MaxWeight) {
  731. return MaxWeight < UINT32_MAX ? 1 : MaxWeight / UINT32_MAX + 1;
  732. }
  733. /// \brief Scale an individual branch weight (and add 1).
  734. ///
  735. /// Scale a 64-bit weight down to 32-bits using \c Scale.
  736. ///
  737. /// According to Laplace's Rule of Succession, it is better to compute the
  738. /// weight based on the count plus 1, so universally add 1 to the value.
  739. ///
  740. /// \pre \c Scale was calculated by \a calculateWeightScale() with a weight no
  741. /// greater than \c Weight.
  742. static uint32_t scaleBranchWeight(uint64_t Weight, uint64_t Scale) {
  743. assert(Scale && "scale by 0?");
  744. uint64_t Scaled = Weight / Scale + 1;
  745. assert(Scaled <= UINT32_MAX && "overflow 32-bits");
  746. return Scaled;
  747. }
  748. llvm::MDNode *CodeGenFunction::createProfileWeights(uint64_t TrueCount,
  749. uint64_t FalseCount) {
  750. // Check for empty weights.
  751. if (!TrueCount && !FalseCount)
  752. return nullptr;
  753. // Calculate how to scale down to 32-bits.
  754. uint64_t Scale = calculateWeightScale(std::max(TrueCount, FalseCount));
  755. llvm::MDBuilder MDHelper(CGM.getLLVMContext());
  756. return MDHelper.createBranchWeights(scaleBranchWeight(TrueCount, Scale),
  757. scaleBranchWeight(FalseCount, Scale));
  758. }
  759. llvm::MDNode *
  760. CodeGenFunction::createProfileWeights(ArrayRef<uint64_t> Weights) {
  761. // We need at least two elements to create meaningful weights.
  762. if (Weights.size() < 2)
  763. return nullptr;
  764. // Check for empty weights.
  765. uint64_t MaxWeight = *std::max_element(Weights.begin(), Weights.end());
  766. if (MaxWeight == 0)
  767. return nullptr;
  768. // Calculate how to scale down to 32-bits.
  769. uint64_t Scale = calculateWeightScale(MaxWeight);
  770. SmallVector<uint32_t, 16> ScaledWeights;
  771. ScaledWeights.reserve(Weights.size());
  772. for (uint64_t W : Weights)
  773. ScaledWeights.push_back(scaleBranchWeight(W, Scale));
  774. llvm::MDBuilder MDHelper(CGM.getLLVMContext());
  775. return MDHelper.createBranchWeights(ScaledWeights);
  776. }
  777. llvm::MDNode *CodeGenFunction::createProfileWeightsForLoop(const Stmt *Cond,
  778. uint64_t LoopCount) {
  779. if (!PGO.haveRegionCounts())
  780. return nullptr;
  781. Optional<uint64_t> CondCount = PGO.getStmtCount(Cond);
  782. assert(CondCount.hasValue() && "missing expected loop condition count");
  783. if (*CondCount == 0)
  784. return nullptr;
  785. return createProfileWeights(LoopCount,
  786. std::max(*CondCount, LoopCount) - LoopCount);
  787. }