CodeGenPGO.cpp 36 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058
  1. //===--- CodeGenPGO.cpp - PGO Instrumentation for LLVM CodeGen --*- C++ -*-===//
  2. //
  3. // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
  4. // See https://llvm.org/LICENSE.txt for license information.
  5. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
  6. //
  7. //===----------------------------------------------------------------------===//
  8. //
  9. // Instrumentation-based profile-guided optimization
  10. //
  11. //===----------------------------------------------------------------------===//
  12. #include "CodeGenPGO.h"
  13. #include "CodeGenFunction.h"
  14. #include "CoverageMappingGen.h"
  15. #include "clang/AST/RecursiveASTVisitor.h"
  16. #include "clang/AST/StmtVisitor.h"
  17. #include "llvm/IR/Intrinsics.h"
  18. #include "llvm/IR/MDBuilder.h"
  19. #include "llvm/Support/Endian.h"
  20. #include "llvm/Support/FileSystem.h"
  21. #include "llvm/Support/MD5.h"
  22. static llvm::cl::opt<bool>
  23. EnableValueProfiling("enable-value-profiling", llvm::cl::ZeroOrMore,
  24. llvm::cl::desc("Enable value profiling"),
  25. llvm::cl::Hidden, llvm::cl::init(false));
  26. using namespace clang;
  27. using namespace CodeGen;
  28. void CodeGenPGO::setFuncName(StringRef Name,
  29. llvm::GlobalValue::LinkageTypes Linkage) {
  30. llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
  31. FuncName = llvm::getPGOFuncName(
  32. Name, Linkage, CGM.getCodeGenOpts().MainFileName,
  33. PGOReader ? PGOReader->getVersion() : llvm::IndexedInstrProf::Version);
  34. // If we're generating a profile, create a variable for the name.
  35. if (CGM.getCodeGenOpts().hasProfileClangInstr())
  36. FuncNameVar = llvm::createPGOFuncNameVar(CGM.getModule(), Linkage, FuncName);
  37. }
  38. void CodeGenPGO::setFuncName(llvm::Function *Fn) {
  39. setFuncName(Fn->getName(), Fn->getLinkage());
  40. // Create PGOFuncName meta data.
  41. llvm::createPGOFuncNameMetadata(*Fn, FuncName);
  42. }
  /// Revisions of the PGO control-flow hash algorithm. Older revisions stay
  /// listed so that hashes can still be computed the way older profiles expect.
  enum PGOHashVersion : unsigned {
    PGO_HASH_V1 = 0,
    PGO_HASH_V2 = 1,

    // Alias for the newest revision; update it whenever one is added.
    PGO_HASH_LATEST = PGO_HASH_V2
  };
  50. namespace {
  51. /// Stable hasher for PGO region counters.
  52. ///
  53. /// PGOHash produces a stable hash of a given function's control flow.
  54. ///
  55. /// Changing the output of this hash will invalidate all previously generated
  56. /// profiles -- i.e., don't do it.
  57. ///
  58. /// \note When this hash does eventually change (years?), we still need to
  59. /// support old hashes. We'll need to pull in the version number from the
  60. /// profile data format and use the matching hash function.
  61. class PGOHash {
  62. uint64_t Working;
  63. unsigned Count;
  64. PGOHashVersion HashVersion;
  65. llvm::MD5 MD5;
  66. static const int NumBitsPerType = 6;
  67. static const unsigned NumTypesPerWord = sizeof(uint64_t) * 8 / NumBitsPerType;
  68. static const unsigned TooBig = 1u << NumBitsPerType;
  69. public:
  70. /// Hash values for AST nodes.
  71. ///
  72. /// Distinct values for AST nodes that have region counters attached.
  73. ///
  74. /// These values must be stable. All new members must be added at the end,
  75. /// and no members should be removed. Changing the enumeration value for an
  76. /// AST node will affect the hash of every function that contains that node.
  77. enum HashType : unsigned char {
  78. None = 0,
  79. LabelStmt = 1,
  80. WhileStmt,
  81. DoStmt,
  82. ForStmt,
  83. CXXForRangeStmt,
  84. ObjCForCollectionStmt,
  85. SwitchStmt,
  86. CaseStmt,
  87. DefaultStmt,
  88. IfStmt,
  89. CXXTryStmt,
  90. CXXCatchStmt,
  91. ConditionalOperator,
  92. BinaryOperatorLAnd,
  93. BinaryOperatorLOr,
  94. BinaryConditionalOperator,
  95. // The preceding values are available with PGO_HASH_V1.
  96. EndOfScope,
  97. IfThenBranch,
  98. IfElseBranch,
  99. GotoStmt,
  100. IndirectGotoStmt,
  101. BreakStmt,
  102. ContinueStmt,
  103. ReturnStmt,
  104. ThrowExpr,
  105. UnaryOperatorLNot,
  106. BinaryOperatorLT,
  107. BinaryOperatorGT,
  108. BinaryOperatorLE,
  109. BinaryOperatorGE,
  110. BinaryOperatorEQ,
  111. BinaryOperatorNE,
  112. // The preceding values are available with PGO_HASH_V2.
  113. // Keep this last. It's for the static assert that follows.
  114. LastHashType
  115. };
  116. static_assert(LastHashType <= TooBig, "Too many types in HashType");
  117. PGOHash(PGOHashVersion HashVersion)
  118. : Working(0), Count(0), HashVersion(HashVersion), MD5() {}
  119. void combine(HashType Type);
  120. uint64_t finalize();
  121. PGOHashVersion getHashVersion() const { return HashVersion; }
  122. };
  123. const int PGOHash::NumBitsPerType;
  124. const unsigned PGOHash::NumTypesPerWord;
  125. const unsigned PGOHash::TooBig;
  126. /// Get the PGO hash version used in the given indexed profile.
  127. static PGOHashVersion getPGOHashVersion(llvm::IndexedInstrProfReader *PGOReader,
  128. CodeGenModule &CGM) {
  129. if (PGOReader->getVersion() <= 4)
  130. return PGO_HASH_V1;
  131. return PGO_HASH_V2;
  132. }
  133. /// A RecursiveASTVisitor that fills a map of statements to PGO counters.
  134. struct MapRegionCounters : public RecursiveASTVisitor<MapRegionCounters> {
  135. using Base = RecursiveASTVisitor<MapRegionCounters>;
  136. /// The next counter value to assign.
  137. unsigned NextCounter;
  138. /// The function hash.
  139. PGOHash Hash;
  140. /// The map of statements to counters.
  141. llvm::DenseMap<const Stmt *, unsigned> &CounterMap;
  142. MapRegionCounters(PGOHashVersion HashVersion,
  143. llvm::DenseMap<const Stmt *, unsigned> &CounterMap)
  144. : NextCounter(0), Hash(HashVersion), CounterMap(CounterMap) {}
  145. // Blocks and lambdas are handled as separate functions, so we need not
  146. // traverse them in the parent context.
  147. bool TraverseBlockExpr(BlockExpr *BE) { return true; }
  148. bool TraverseLambdaExpr(LambdaExpr *LE) {
  149. // Traverse the captures, but not the body.
  150. for (const auto &C : zip(LE->captures(), LE->capture_inits()))
  151. TraverseLambdaCapture(LE, &std::get<0>(C), std::get<1>(C));
  152. return true;
  153. }
  154. bool TraverseCapturedStmt(CapturedStmt *CS) { return true; }
  155. bool VisitDecl(const Decl *D) {
  156. switch (D->getKind()) {
  157. default:
  158. break;
  159. case Decl::Function:
  160. case Decl::CXXMethod:
  161. case Decl::CXXConstructor:
  162. case Decl::CXXDestructor:
  163. case Decl::CXXConversion:
  164. case Decl::ObjCMethod:
  165. case Decl::Block:
  166. case Decl::Captured:
  167. CounterMap[D->getBody()] = NextCounter++;
  168. break;
  169. }
  170. return true;
  171. }
  172. /// If \p S gets a fresh counter, update the counter mappings. Return the
  173. /// V1 hash of \p S.
  174. PGOHash::HashType updateCounterMappings(Stmt *S) {
  175. auto Type = getHashType(PGO_HASH_V1, S);
  176. if (Type != PGOHash::None)
  177. CounterMap[S] = NextCounter++;
  178. return Type;
  179. }
  180. /// Include \p S in the function hash.
  181. bool VisitStmt(Stmt *S) {
  182. auto Type = updateCounterMappings(S);
  183. if (Hash.getHashVersion() != PGO_HASH_V1)
  184. Type = getHashType(Hash.getHashVersion(), S);
  185. if (Type != PGOHash::None)
  186. Hash.combine(Type);
  187. return true;
  188. }
  189. bool TraverseIfStmt(IfStmt *If) {
  190. // If we used the V1 hash, use the default traversal.
  191. if (Hash.getHashVersion() == PGO_HASH_V1)
  192. return Base::TraverseIfStmt(If);
  193. // Otherwise, keep track of which branch we're in while traversing.
  194. VisitStmt(If);
  195. for (Stmt *CS : If->children()) {
  196. if (!CS)
  197. continue;
  198. if (CS == If->getThen())
  199. Hash.combine(PGOHash::IfThenBranch);
  200. else if (CS == If->getElse())
  201. Hash.combine(PGOHash::IfElseBranch);
  202. TraverseStmt(CS);
  203. }
  204. Hash.combine(PGOHash::EndOfScope);
  205. return true;
  206. }
  207. // If the statement type \p N is nestable, and its nesting impacts profile
  208. // stability, define a custom traversal which tracks the end of the statement
  209. // in the hash (provided we're not using the V1 hash).
  210. #define DEFINE_NESTABLE_TRAVERSAL(N) \
  211. bool Traverse##N(N *S) { \
  212. Base::Traverse##N(S); \
  213. if (Hash.getHashVersion() != PGO_HASH_V1) \
  214. Hash.combine(PGOHash::EndOfScope); \
  215. return true; \
  216. }
  217. DEFINE_NESTABLE_TRAVERSAL(WhileStmt)
  218. DEFINE_NESTABLE_TRAVERSAL(DoStmt)
  219. DEFINE_NESTABLE_TRAVERSAL(ForStmt)
  220. DEFINE_NESTABLE_TRAVERSAL(CXXForRangeStmt)
  221. DEFINE_NESTABLE_TRAVERSAL(ObjCForCollectionStmt)
  222. DEFINE_NESTABLE_TRAVERSAL(CXXTryStmt)
  223. DEFINE_NESTABLE_TRAVERSAL(CXXCatchStmt)
  224. /// Get version \p HashVersion of the PGO hash for \p S.
  225. PGOHash::HashType getHashType(PGOHashVersion HashVersion, const Stmt *S) {
  226. switch (S->getStmtClass()) {
  227. default:
  228. break;
  229. case Stmt::LabelStmtClass:
  230. return PGOHash::LabelStmt;
  231. case Stmt::WhileStmtClass:
  232. return PGOHash::WhileStmt;
  233. case Stmt::DoStmtClass:
  234. return PGOHash::DoStmt;
  235. case Stmt::ForStmtClass:
  236. return PGOHash::ForStmt;
  237. case Stmt::CXXForRangeStmtClass:
  238. return PGOHash::CXXForRangeStmt;
  239. case Stmt::ObjCForCollectionStmtClass:
  240. return PGOHash::ObjCForCollectionStmt;
  241. case Stmt::SwitchStmtClass:
  242. return PGOHash::SwitchStmt;
  243. case Stmt::CaseStmtClass:
  244. return PGOHash::CaseStmt;
  245. case Stmt::DefaultStmtClass:
  246. return PGOHash::DefaultStmt;
  247. case Stmt::IfStmtClass:
  248. return PGOHash::IfStmt;
  249. case Stmt::CXXTryStmtClass:
  250. return PGOHash::CXXTryStmt;
  251. case Stmt::CXXCatchStmtClass:
  252. return PGOHash::CXXCatchStmt;
  253. case Stmt::ConditionalOperatorClass:
  254. return PGOHash::ConditionalOperator;
  255. case Stmt::BinaryConditionalOperatorClass:
  256. return PGOHash::BinaryConditionalOperator;
  257. case Stmt::BinaryOperatorClass: {
  258. const BinaryOperator *BO = cast<BinaryOperator>(S);
  259. if (BO->getOpcode() == BO_LAnd)
  260. return PGOHash::BinaryOperatorLAnd;
  261. if (BO->getOpcode() == BO_LOr)
  262. return PGOHash::BinaryOperatorLOr;
  263. if (HashVersion == PGO_HASH_V2) {
  264. switch (BO->getOpcode()) {
  265. default:
  266. break;
  267. case BO_LT:
  268. return PGOHash::BinaryOperatorLT;
  269. case BO_GT:
  270. return PGOHash::BinaryOperatorGT;
  271. case BO_LE:
  272. return PGOHash::BinaryOperatorLE;
  273. case BO_GE:
  274. return PGOHash::BinaryOperatorGE;
  275. case BO_EQ:
  276. return PGOHash::BinaryOperatorEQ;
  277. case BO_NE:
  278. return PGOHash::BinaryOperatorNE;
  279. }
  280. }
  281. break;
  282. }
  283. }
  284. if (HashVersion == PGO_HASH_V2) {
  285. switch (S->getStmtClass()) {
  286. default:
  287. break;
  288. case Stmt::GotoStmtClass:
  289. return PGOHash::GotoStmt;
  290. case Stmt::IndirectGotoStmtClass:
  291. return PGOHash::IndirectGotoStmt;
  292. case Stmt::BreakStmtClass:
  293. return PGOHash::BreakStmt;
  294. case Stmt::ContinueStmtClass:
  295. return PGOHash::ContinueStmt;
  296. case Stmt::ReturnStmtClass:
  297. return PGOHash::ReturnStmt;
  298. case Stmt::CXXThrowExprClass:
  299. return PGOHash::ThrowExpr;
  300. case Stmt::UnaryOperatorClass: {
  301. const UnaryOperator *UO = cast<UnaryOperator>(S);
  302. if (UO->getOpcode() == UO_LNot)
  303. return PGOHash::UnaryOperatorLNot;
  304. break;
  305. }
  306. }
  307. }
  308. return PGOHash::None;
  309. }
  310. };
  311. /// A StmtVisitor that propagates the raw counts through the AST and
  312. /// records the count at statements where the value may change.
  313. struct ComputeRegionCounts : public ConstStmtVisitor<ComputeRegionCounts> {
  314. /// PGO state.
  315. CodeGenPGO &PGO;
  316. /// A flag that is set when the current count should be recorded on the
  317. /// next statement, such as at the exit of a loop.
  318. bool RecordNextStmtCount;
  319. /// The count at the current location in the traversal.
  320. uint64_t CurrentCount;
  321. /// The map of statements to count values.
  322. llvm::DenseMap<const Stmt *, uint64_t> &CountMap;
  323. /// BreakContinueStack - Keep counts of breaks and continues inside loops.
  324. struct BreakContinue {
  325. uint64_t BreakCount;
  326. uint64_t ContinueCount;
  327. BreakContinue() : BreakCount(0), ContinueCount(0) {}
  328. };
  329. SmallVector<BreakContinue, 8> BreakContinueStack;
  330. ComputeRegionCounts(llvm::DenseMap<const Stmt *, uint64_t> &CountMap,
  331. CodeGenPGO &PGO)
  332. : PGO(PGO), RecordNextStmtCount(false), CountMap(CountMap) {}
  333. void RecordStmtCount(const Stmt *S) {
  334. if (RecordNextStmtCount) {
  335. CountMap[S] = CurrentCount;
  336. RecordNextStmtCount = false;
  337. }
  338. }
  339. /// Set and return the current count.
  340. uint64_t setCount(uint64_t Count) {
  341. CurrentCount = Count;
  342. return Count;
  343. }
  344. void VisitStmt(const Stmt *S) {
  345. RecordStmtCount(S);
  346. for (const Stmt *Child : S->children())
  347. if (Child)
  348. this->Visit(Child);
  349. }
  350. void VisitFunctionDecl(const FunctionDecl *D) {
  351. // Counter tracks entry to the function body.
  352. uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
  353. CountMap[D->getBody()] = BodyCount;
  354. Visit(D->getBody());
  355. }
  356. // Skip lambda expressions. We visit these as FunctionDecls when we're
  357. // generating them and aren't interested in the body when generating a
  358. // parent context.
  359. void VisitLambdaExpr(const LambdaExpr *LE) {}
  360. void VisitCapturedDecl(const CapturedDecl *D) {
  361. // Counter tracks entry to the capture body.
  362. uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
  363. CountMap[D->getBody()] = BodyCount;
  364. Visit(D->getBody());
  365. }
  366. void VisitObjCMethodDecl(const ObjCMethodDecl *D) {
  367. // Counter tracks entry to the method body.
  368. uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
  369. CountMap[D->getBody()] = BodyCount;
  370. Visit(D->getBody());
  371. }
  372. void VisitBlockDecl(const BlockDecl *D) {
  373. // Counter tracks entry to the block body.
  374. uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
  375. CountMap[D->getBody()] = BodyCount;
  376. Visit(D->getBody());
  377. }
  378. void VisitReturnStmt(const ReturnStmt *S) {
  379. RecordStmtCount(S);
  380. if (S->getRetValue())
  381. Visit(S->getRetValue());
  382. CurrentCount = 0;
  383. RecordNextStmtCount = true;
  384. }
  385. void VisitCXXThrowExpr(const CXXThrowExpr *E) {
  386. RecordStmtCount(E);
  387. if (E->getSubExpr())
  388. Visit(E->getSubExpr());
  389. CurrentCount = 0;
  390. RecordNextStmtCount = true;
  391. }
  392. void VisitGotoStmt(const GotoStmt *S) {
  393. RecordStmtCount(S);
  394. CurrentCount = 0;
  395. RecordNextStmtCount = true;
  396. }
  397. void VisitLabelStmt(const LabelStmt *S) {
  398. RecordNextStmtCount = false;
  399. // Counter tracks the block following the label.
  400. uint64_t BlockCount = setCount(PGO.getRegionCount(S));
  401. CountMap[S] = BlockCount;
  402. Visit(S->getSubStmt());
  403. }
  404. void VisitBreakStmt(const BreakStmt *S) {
  405. RecordStmtCount(S);
  406. assert(!BreakContinueStack.empty() && "break not in a loop or switch!");
  407. BreakContinueStack.back().BreakCount += CurrentCount;
  408. CurrentCount = 0;
  409. RecordNextStmtCount = true;
  410. }
  411. void VisitContinueStmt(const ContinueStmt *S) {
  412. RecordStmtCount(S);
  413. assert(!BreakContinueStack.empty() && "continue stmt not in a loop!");
  414. BreakContinueStack.back().ContinueCount += CurrentCount;
  415. CurrentCount = 0;
  416. RecordNextStmtCount = true;
  417. }
  418. void VisitWhileStmt(const WhileStmt *S) {
  419. RecordStmtCount(S);
  420. uint64_t ParentCount = CurrentCount;
  421. BreakContinueStack.push_back(BreakContinue());
  422. // Visit the body region first so the break/continue adjustments can be
  423. // included when visiting the condition.
  424. uint64_t BodyCount = setCount(PGO.getRegionCount(S));
  425. CountMap[S->getBody()] = CurrentCount;
  426. Visit(S->getBody());
  427. uint64_t BackedgeCount = CurrentCount;
  428. // ...then go back and propagate counts through the condition. The count
  429. // at the start of the condition is the sum of the incoming edges,
  430. // the backedge from the end of the loop body, and the edges from
  431. // continue statements.
  432. BreakContinue BC = BreakContinueStack.pop_back_val();
  433. uint64_t CondCount =
  434. setCount(ParentCount + BackedgeCount + BC.ContinueCount);
  435. CountMap[S->getCond()] = CondCount;
  436. Visit(S->getCond());
  437. setCount(BC.BreakCount + CondCount - BodyCount);
  438. RecordNextStmtCount = true;
  439. }
  440. void VisitDoStmt(const DoStmt *S) {
  441. RecordStmtCount(S);
  442. uint64_t LoopCount = PGO.getRegionCount(S);
  443. BreakContinueStack.push_back(BreakContinue());
  444. // The count doesn't include the fallthrough from the parent scope. Add it.
  445. uint64_t BodyCount = setCount(LoopCount + CurrentCount);
  446. CountMap[S->getBody()] = BodyCount;
  447. Visit(S->getBody());
  448. uint64_t BackedgeCount = CurrentCount;
  449. BreakContinue BC = BreakContinueStack.pop_back_val();
  450. // The count at the start of the condition is equal to the count at the
  451. // end of the body, plus any continues.
  452. uint64_t CondCount = setCount(BackedgeCount + BC.ContinueCount);
  453. CountMap[S->getCond()] = CondCount;
  454. Visit(S->getCond());
  455. setCount(BC.BreakCount + CondCount - LoopCount);
  456. RecordNextStmtCount = true;
  457. }
  458. void VisitForStmt(const ForStmt *S) {
  459. RecordStmtCount(S);
  460. if (S->getInit())
  461. Visit(S->getInit());
  462. uint64_t ParentCount = CurrentCount;
  463. BreakContinueStack.push_back(BreakContinue());
  464. // Visit the body region first. (This is basically the same as a while
  465. // loop; see further comments in VisitWhileStmt.)
  466. uint64_t BodyCount = setCount(PGO.getRegionCount(S));
  467. CountMap[S->getBody()] = BodyCount;
  468. Visit(S->getBody());
  469. uint64_t BackedgeCount = CurrentCount;
  470. BreakContinue BC = BreakContinueStack.pop_back_val();
  471. // The increment is essentially part of the body but it needs to include
  472. // the count for all the continue statements.
  473. if (S->getInc()) {
  474. uint64_t IncCount = setCount(BackedgeCount + BC.ContinueCount);
  475. CountMap[S->getInc()] = IncCount;
  476. Visit(S->getInc());
  477. }
  478. // ...then go back and propagate counts through the condition.
  479. uint64_t CondCount =
  480. setCount(ParentCount + BackedgeCount + BC.ContinueCount);
  481. if (S->getCond()) {
  482. CountMap[S->getCond()] = CondCount;
  483. Visit(S->getCond());
  484. }
  485. setCount(BC.BreakCount + CondCount - BodyCount);
  486. RecordNextStmtCount = true;
  487. }
  488. void VisitCXXForRangeStmt(const CXXForRangeStmt *S) {
  489. RecordStmtCount(S);
  490. if (S->getInit())
  491. Visit(S->getInit());
  492. Visit(S->getLoopVarStmt());
  493. Visit(S->getRangeStmt());
  494. Visit(S->getBeginStmt());
  495. Visit(S->getEndStmt());
  496. uint64_t ParentCount = CurrentCount;
  497. BreakContinueStack.push_back(BreakContinue());
  498. // Visit the body region first. (This is basically the same as a while
  499. // loop; see further comments in VisitWhileStmt.)
  500. uint64_t BodyCount = setCount(PGO.getRegionCount(S));
  501. CountMap[S->getBody()] = BodyCount;
  502. Visit(S->getBody());
  503. uint64_t BackedgeCount = CurrentCount;
  504. BreakContinue BC = BreakContinueStack.pop_back_val();
  505. // The increment is essentially part of the body but it needs to include
  506. // the count for all the continue statements.
  507. uint64_t IncCount = setCount(BackedgeCount + BC.ContinueCount);
  508. CountMap[S->getInc()] = IncCount;
  509. Visit(S->getInc());
  510. // ...then go back and propagate counts through the condition.
  511. uint64_t CondCount =
  512. setCount(ParentCount + BackedgeCount + BC.ContinueCount);
  513. CountMap[S->getCond()] = CondCount;
  514. Visit(S->getCond());
  515. setCount(BC.BreakCount + CondCount - BodyCount);
  516. RecordNextStmtCount = true;
  517. }
  518. void VisitObjCForCollectionStmt(const ObjCForCollectionStmt *S) {
  519. RecordStmtCount(S);
  520. Visit(S->getElement());
  521. uint64_t ParentCount = CurrentCount;
  522. BreakContinueStack.push_back(BreakContinue());
  523. // Counter tracks the body of the loop.
  524. uint64_t BodyCount = setCount(PGO.getRegionCount(S));
  525. CountMap[S->getBody()] = BodyCount;
  526. Visit(S->getBody());
  527. uint64_t BackedgeCount = CurrentCount;
  528. BreakContinue BC = BreakContinueStack.pop_back_val();
  529. setCount(BC.BreakCount + ParentCount + BackedgeCount + BC.ContinueCount -
  530. BodyCount);
  531. RecordNextStmtCount = true;
  532. }
  533. void VisitSwitchStmt(const SwitchStmt *S) {
  534. RecordStmtCount(S);
  535. if (S->getInit())
  536. Visit(S->getInit());
  537. Visit(S->getCond());
  538. CurrentCount = 0;
  539. BreakContinueStack.push_back(BreakContinue());
  540. Visit(S->getBody());
  541. // If the switch is inside a loop, add the continue counts.
  542. BreakContinue BC = BreakContinueStack.pop_back_val();
  543. if (!BreakContinueStack.empty())
  544. BreakContinueStack.back().ContinueCount += BC.ContinueCount;
  545. // Counter tracks the exit block of the switch.
  546. setCount(PGO.getRegionCount(S));
  547. RecordNextStmtCount = true;
  548. }
  549. void VisitSwitchCase(const SwitchCase *S) {
  550. RecordNextStmtCount = false;
  551. // Counter for this particular case. This counts only jumps from the
  552. // switch header and does not include fallthrough from the case before
  553. // this one.
  554. uint64_t CaseCount = PGO.getRegionCount(S);
  555. setCount(CurrentCount + CaseCount);
  556. // We need the count without fallthrough in the mapping, so it's more useful
  557. // for branch probabilities.
  558. CountMap[S] = CaseCount;
  559. RecordNextStmtCount = true;
  560. Visit(S->getSubStmt());
  561. }
  562. void VisitIfStmt(const IfStmt *S) {
  563. RecordStmtCount(S);
  564. uint64_t ParentCount = CurrentCount;
  565. if (S->getInit())
  566. Visit(S->getInit());
  567. Visit(S->getCond());
  568. // Counter tracks the "then" part of an if statement. The count for
  569. // the "else" part, if it exists, will be calculated from this counter.
  570. uint64_t ThenCount = setCount(PGO.getRegionCount(S));
  571. CountMap[S->getThen()] = ThenCount;
  572. Visit(S->getThen());
  573. uint64_t OutCount = CurrentCount;
  574. uint64_t ElseCount = ParentCount - ThenCount;
  575. if (S->getElse()) {
  576. setCount(ElseCount);
  577. CountMap[S->getElse()] = ElseCount;
  578. Visit(S->getElse());
  579. OutCount += CurrentCount;
  580. } else
  581. OutCount += ElseCount;
  582. setCount(OutCount);
  583. RecordNextStmtCount = true;
  584. }
  585. void VisitCXXTryStmt(const CXXTryStmt *S) {
  586. RecordStmtCount(S);
  587. Visit(S->getTryBlock());
  588. for (unsigned I = 0, E = S->getNumHandlers(); I < E; ++I)
  589. Visit(S->getHandler(I));
  590. // Counter tracks the continuation block of the try statement.
  591. setCount(PGO.getRegionCount(S));
  592. RecordNextStmtCount = true;
  593. }
  594. void VisitCXXCatchStmt(const CXXCatchStmt *S) {
  595. RecordNextStmtCount = false;
  596. // Counter tracks the catch statement's handler block.
  597. uint64_t CatchCount = setCount(PGO.getRegionCount(S));
  598. CountMap[S] = CatchCount;
  599. Visit(S->getHandlerBlock());
  600. }
  601. void VisitAbstractConditionalOperator(const AbstractConditionalOperator *E) {
  602. RecordStmtCount(E);
  603. uint64_t ParentCount = CurrentCount;
  604. Visit(E->getCond());
  605. // Counter tracks the "true" part of a conditional operator. The
  606. // count in the "false" part will be calculated from this counter.
  607. uint64_t TrueCount = setCount(PGO.getRegionCount(E));
  608. CountMap[E->getTrueExpr()] = TrueCount;
  609. Visit(E->getTrueExpr());
  610. uint64_t OutCount = CurrentCount;
  611. uint64_t FalseCount = setCount(ParentCount - TrueCount);
  612. CountMap[E->getFalseExpr()] = FalseCount;
  613. Visit(E->getFalseExpr());
  614. OutCount += CurrentCount;
  615. setCount(OutCount);
  616. RecordNextStmtCount = true;
  617. }
  618. void VisitBinLAnd(const BinaryOperator *E) {
  619. RecordStmtCount(E);
  620. uint64_t ParentCount = CurrentCount;
  621. Visit(E->getLHS());
  622. // Counter tracks the right hand side of a logical and operator.
  623. uint64_t RHSCount = setCount(PGO.getRegionCount(E));
  624. CountMap[E->getRHS()] = RHSCount;
  625. Visit(E->getRHS());
  626. setCount(ParentCount + RHSCount - CurrentCount);
  627. RecordNextStmtCount = true;
  628. }
  629. void VisitBinLOr(const BinaryOperator *E) {
  630. RecordStmtCount(E);
  631. uint64_t ParentCount = CurrentCount;
  632. Visit(E->getLHS());
  633. // Counter tracks the right hand side of a logical or operator.
  634. uint64_t RHSCount = setCount(PGO.getRegionCount(E));
  635. CountMap[E->getRHS()] = RHSCount;
  636. Visit(E->getRHS());
  637. setCount(ParentCount + RHSCount - CurrentCount);
  638. RecordNextStmtCount = true;
  639. }
  640. };
  641. } // end anonymous namespace
  642. void PGOHash::combine(HashType Type) {
  643. // Check that we never combine 0 and only have six bits.
  644. assert(Type && "Hash is invalid: unexpected type 0");
  645. assert(unsigned(Type) < TooBig && "Hash is invalid: too many types");
  646. // Pass through MD5 if enough work has built up.
  647. if (Count && Count % NumTypesPerWord == 0) {
  648. using namespace llvm::support;
  649. uint64_t Swapped = endian::byte_swap<uint64_t, little>(Working);
  650. MD5.update(llvm::makeArrayRef((uint8_t *)&Swapped, sizeof(Swapped)));
  651. Working = 0;
  652. }
  653. // Accumulate the current type.
  654. ++Count;
  655. Working = Working << NumBitsPerType | Type;
  656. }
  657. uint64_t PGOHash::finalize() {
  658. // Use Working as the hash directly if we never used MD5.
  659. if (Count <= NumTypesPerWord)
  660. // No need to byte swap here, since none of the math was endian-dependent.
  661. // This number will be byte-swapped as required on endianness transitions,
  662. // so we will see the same value on the other side.
  663. return Working;
  664. // Check for remaining work in Working.
  665. if (Working)
  666. MD5.update(Working);
  667. // Finalize the MD5 and return the hash.
  668. llvm::MD5::MD5Result Result;
  669. MD5.final(Result);
  670. using namespace llvm::support;
  671. return Result.low();
  672. }
  673. void CodeGenPGO::assignRegionCounters(GlobalDecl GD, llvm::Function *Fn) {
  674. const Decl *D = GD.getDecl();
  675. if (!D->hasBody())
  676. return;
  677. bool InstrumentRegions = CGM.getCodeGenOpts().hasProfileClangInstr();
  678. llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
  679. if (!InstrumentRegions && !PGOReader)
  680. return;
  681. if (D->isImplicit())
  682. return;
  683. // Constructors and destructors may be represented by several functions in IR.
  684. // If so, instrument only base variant, others are implemented by delegation
  685. // to the base one, it would be counted twice otherwise.
  686. if (CGM.getTarget().getCXXABI().hasConstructorVariants()) {
  687. if (const auto *CCD = dyn_cast<CXXConstructorDecl>(D))
  688. if (GD.getCtorType() != Ctor_Base &&
  689. CodeGenFunction::IsConstructorDelegationValid(CCD))
  690. return;
  691. }
  692. if (isa<CXXDestructorDecl>(D) && GD.getDtorType() != Dtor_Base)
  693. return;
  694. CGM.ClearUnusedCoverageMapping(D);
  695. setFuncName(Fn);
  696. mapRegionCounters(D);
  697. if (CGM.getCodeGenOpts().CoverageMapping)
  698. emitCounterRegionMapping(D);
  699. if (PGOReader) {
  700. SourceManager &SM = CGM.getContext().getSourceManager();
  701. loadRegionCounts(PGOReader, SM.isInMainFile(D->getLocation()));
  702. computeRegionCounts(D);
  703. applyFunctionAttributes(PGOReader, Fn);
  704. }
  705. }
  706. void CodeGenPGO::mapRegionCounters(const Decl *D) {
  707. // Use the latest hash version when inserting instrumentation, but use the
  708. // version in the indexed profile if we're reading PGO data.
  709. PGOHashVersion HashVersion = PGO_HASH_LATEST;
  710. if (auto *PGOReader = CGM.getPGOReader())
  711. HashVersion = getPGOHashVersion(PGOReader, CGM);
  712. RegionCounterMap.reset(new llvm::DenseMap<const Stmt *, unsigned>);
  713. MapRegionCounters Walker(HashVersion, *RegionCounterMap);
  714. if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D))
  715. Walker.TraverseDecl(const_cast<FunctionDecl *>(FD));
  716. else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D))
  717. Walker.TraverseDecl(const_cast<ObjCMethodDecl *>(MD));
  718. else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D))
  719. Walker.TraverseDecl(const_cast<BlockDecl *>(BD));
  720. else if (const CapturedDecl *CD = dyn_cast_or_null<CapturedDecl>(D))
  721. Walker.TraverseDecl(const_cast<CapturedDecl *>(CD));
  722. assert(Walker.NextCounter > 0 && "no entry counter mapped for decl");
  723. NumRegionCounters = Walker.NextCounter;
  724. FunctionHash = Walker.Hash.finalize();
  725. }
  726. bool CodeGenPGO::skipRegionMappingForDecl(const Decl *D) {
  727. if (!D->getBody())
  728. return true;
  729. // Don't map the functions in system headers.
  730. const auto &SM = CGM.getContext().getSourceManager();
  731. auto Loc = D->getBody()->getBeginLoc();
  732. return SM.isInSystemHeader(Loc);
  733. }
  734. void CodeGenPGO::emitCounterRegionMapping(const Decl *D) {
  735. if (skipRegionMappingForDecl(D))
  736. return;
  737. std::string CoverageMapping;
  738. llvm::raw_string_ostream OS(CoverageMapping);
  739. CoverageMappingGen MappingGen(*CGM.getCoverageMapping(),
  740. CGM.getContext().getSourceManager(),
  741. CGM.getLangOpts(), RegionCounterMap.get());
  742. MappingGen.emitCounterMapping(D, OS);
  743. OS.flush();
  744. if (CoverageMapping.empty())
  745. return;
  746. CGM.getCoverageMapping()->addFunctionMappingRecord(
  747. FuncNameVar, FuncName, FunctionHash, CoverageMapping);
  748. }
  749. void
  750. CodeGenPGO::emitEmptyCounterMapping(const Decl *D, StringRef Name,
  751. llvm::GlobalValue::LinkageTypes Linkage) {
  752. if (skipRegionMappingForDecl(D))
  753. return;
  754. std::string CoverageMapping;
  755. llvm::raw_string_ostream OS(CoverageMapping);
  756. CoverageMappingGen MappingGen(*CGM.getCoverageMapping(),
  757. CGM.getContext().getSourceManager(),
  758. CGM.getLangOpts());
  759. MappingGen.emitEmptyMapping(D, OS);
  760. OS.flush();
  761. if (CoverageMapping.empty())
  762. return;
  763. setFuncName(Name, Linkage);
  764. CGM.getCoverageMapping()->addFunctionMappingRecord(
  765. FuncNameVar, FuncName, FunctionHash, CoverageMapping, false);
  766. }
  767. void CodeGenPGO::computeRegionCounts(const Decl *D) {
  768. StmtCountMap.reset(new llvm::DenseMap<const Stmt *, uint64_t>);
  769. ComputeRegionCounts Walker(*StmtCountMap, *this);
  770. if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D))
  771. Walker.VisitFunctionDecl(FD);
  772. else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D))
  773. Walker.VisitObjCMethodDecl(MD);
  774. else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D))
  775. Walker.VisitBlockDecl(BD);
  776. else if (const CapturedDecl *CD = dyn_cast_or_null<CapturedDecl>(D))
  777. Walker.VisitCapturedDecl(const_cast<CapturedDecl *>(CD));
  778. }
  779. void
  780. CodeGenPGO::applyFunctionAttributes(llvm::IndexedInstrProfReader *PGOReader,
  781. llvm::Function *Fn) {
  782. if (!haveRegionCounts())
  783. return;
  784. uint64_t FunctionCount = getRegionCount(nullptr);
  785. Fn->setEntryCount(FunctionCount);
  786. }
  787. void CodeGenPGO::emitCounterIncrement(CGBuilderTy &Builder, const Stmt *S,
  788. llvm::Value *StepV) {
  789. if (!CGM.getCodeGenOpts().hasProfileClangInstr() || !RegionCounterMap)
  790. return;
  791. if (!Builder.GetInsertBlock())
  792. return;
  793. unsigned Counter = (*RegionCounterMap)[S];
  794. auto *I8PtrTy = llvm::Type::getInt8PtrTy(CGM.getLLVMContext());
  795. llvm::Value *Args[] = {llvm::ConstantExpr::getBitCast(FuncNameVar, I8PtrTy),
  796. Builder.getInt64(FunctionHash),
  797. Builder.getInt32(NumRegionCounters),
  798. Builder.getInt32(Counter), StepV};
  799. if (!StepV)
  800. Builder.CreateCall(CGM.getIntrinsic(llvm::Intrinsic::instrprof_increment),
  801. makeArrayRef(Args, 4));
  802. else
  803. Builder.CreateCall(
  804. CGM.getIntrinsic(llvm::Intrinsic::instrprof_increment_step),
  805. makeArrayRef(Args));
  806. }
  807. // This method either inserts a call to the profile run-time during
  808. // instrumentation or puts profile data into metadata for PGO use.
  809. void CodeGenPGO::valueProfile(CGBuilderTy &Builder, uint32_t ValueKind,
  810. llvm::Instruction *ValueSite, llvm::Value *ValuePtr) {
  811. if (!EnableValueProfiling)
  812. return;
  813. if (!ValuePtr || !ValueSite || !Builder.GetInsertBlock())
  814. return;
  815. if (isa<llvm::Constant>(ValuePtr))
  816. return;
  817. bool InstrumentValueSites = CGM.getCodeGenOpts().hasProfileClangInstr();
  818. if (InstrumentValueSites && RegionCounterMap) {
  819. auto BuilderInsertPoint = Builder.saveIP();
  820. Builder.SetInsertPoint(ValueSite);
  821. llvm::Value *Args[5] = {
  822. llvm::ConstantExpr::getBitCast(FuncNameVar, Builder.getInt8PtrTy()),
  823. Builder.getInt64(FunctionHash),
  824. Builder.CreatePtrToInt(ValuePtr, Builder.getInt64Ty()),
  825. Builder.getInt32(ValueKind),
  826. Builder.getInt32(NumValueSites[ValueKind]++)
  827. };
  828. Builder.CreateCall(
  829. CGM.getIntrinsic(llvm::Intrinsic::instrprof_value_profile), Args);
  830. Builder.restoreIP(BuilderInsertPoint);
  831. return;
  832. }
  833. llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
  834. if (PGOReader && haveRegionCounts()) {
  835. // We record the top most called three functions at each call site.
  836. // Profile metadata contains "VP" string identifying this metadata
  837. // as value profiling data, then a uint32_t value for the value profiling
  838. // kind, a uint64_t value for the total number of times the call is
  839. // executed, followed by the function hash and execution count (uint64_t)
  840. // pairs for each function.
  841. if (NumValueSites[ValueKind] >= ProfRecord->getNumValueSites(ValueKind))
  842. return;
  843. llvm::annotateValueSite(CGM.getModule(), *ValueSite, *ProfRecord,
  844. (llvm::InstrProfValueKind)ValueKind,
  845. NumValueSites[ValueKind]);
  846. NumValueSites[ValueKind]++;
  847. }
  848. }
  849. void CodeGenPGO::loadRegionCounts(llvm::IndexedInstrProfReader *PGOReader,
  850. bool IsInMainFile) {
  851. CGM.getPGOStats().addVisited(IsInMainFile);
  852. RegionCounts.clear();
  853. llvm::Expected<llvm::InstrProfRecord> RecordExpected =
  854. PGOReader->getInstrProfRecord(FuncName, FunctionHash);
  855. if (auto E = RecordExpected.takeError()) {
  856. auto IPE = llvm::InstrProfError::take(std::move(E));
  857. if (IPE == llvm::instrprof_error::unknown_function)
  858. CGM.getPGOStats().addMissing(IsInMainFile);
  859. else if (IPE == llvm::instrprof_error::hash_mismatch)
  860. CGM.getPGOStats().addMismatched(IsInMainFile);
  861. else if (IPE == llvm::instrprof_error::malformed)
  862. // TODO: Consider a more specific warning for this case.
  863. CGM.getPGOStats().addMismatched(IsInMainFile);
  864. return;
  865. }
  866. ProfRecord =
  867. std::make_unique<llvm::InstrProfRecord>(std::move(RecordExpected.get()));
  868. RegionCounts = ProfRecord->Counts;
  869. }
  870. /// Calculate what to divide by to scale weights.
  871. ///
  872. /// Given the maximum weight, calculate a divisor that will scale all the
  873. /// weights to strictly less than UINT32_MAX.
  874. static uint64_t calculateWeightScale(uint64_t MaxWeight) {
  875. return MaxWeight < UINT32_MAX ? 1 : MaxWeight / UINT32_MAX + 1;
  876. }
  877. /// Scale an individual branch weight (and add 1).
  878. ///
  879. /// Scale a 64-bit weight down to 32-bits using \c Scale.
  880. ///
  881. /// According to Laplace's Rule of Succession, it is better to compute the
  882. /// weight based on the count plus 1, so universally add 1 to the value.
  883. ///
  884. /// \pre \c Scale was calculated by \a calculateWeightScale() with a weight no
  885. /// greater than \c Weight.
  886. static uint32_t scaleBranchWeight(uint64_t Weight, uint64_t Scale) {
  887. assert(Scale && "scale by 0?");
  888. uint64_t Scaled = Weight / Scale + 1;
  889. assert(Scaled <= UINT32_MAX && "overflow 32-bits");
  890. return Scaled;
  891. }
  892. llvm::MDNode *CodeGenFunction::createProfileWeights(uint64_t TrueCount,
  893. uint64_t FalseCount) {
  894. // Check for empty weights.
  895. if (!TrueCount && !FalseCount)
  896. return nullptr;
  897. // Calculate how to scale down to 32-bits.
  898. uint64_t Scale = calculateWeightScale(std::max(TrueCount, FalseCount));
  899. llvm::MDBuilder MDHelper(CGM.getLLVMContext());
  900. return MDHelper.createBranchWeights(scaleBranchWeight(TrueCount, Scale),
  901. scaleBranchWeight(FalseCount, Scale));
  902. }
  903. llvm::MDNode *
  904. CodeGenFunction::createProfileWeights(ArrayRef<uint64_t> Weights) {
  905. // We need at least two elements to create meaningful weights.
  906. if (Weights.size() < 2)
  907. return nullptr;
  908. // Check for empty weights.
  909. uint64_t MaxWeight = *std::max_element(Weights.begin(), Weights.end());
  910. if (MaxWeight == 0)
  911. return nullptr;
  912. // Calculate how to scale down to 32-bits.
  913. uint64_t Scale = calculateWeightScale(MaxWeight);
  914. SmallVector<uint32_t, 16> ScaledWeights;
  915. ScaledWeights.reserve(Weights.size());
  916. for (uint64_t W : Weights)
  917. ScaledWeights.push_back(scaleBranchWeight(W, Scale));
  918. llvm::MDBuilder MDHelper(CGM.getLLVMContext());
  919. return MDHelper.createBranchWeights(ScaledWeights);
  920. }
  921. llvm::MDNode *CodeGenFunction::createProfileWeightsForLoop(const Stmt *Cond,
  922. uint64_t LoopCount) {
  923. if (!PGO.haveRegionCounts())
  924. return nullptr;
  925. Optional<uint64_t> CondCount = PGO.getStmtCount(Cond);
  926. assert(CondCount.hasValue() && "missing expected loop condition count");
  927. if (*CondCount == 0)
  928. return nullptr;
  929. return createProfileWeights(LoopCount,
  930. std::max(*CondCount, LoopCount) - LoopCount);
  931. }