//===- WholeProgramDevirt.cpp - Whole program virtual call optimization ---===//
//
//                     The LLVM Compiler Infrastructure
//
// This file is distributed under the University of Illinois Open Source
// License. See LICENSE.TXT for details.
//
//===----------------------------------------------------------------------===//
//
// This pass implements whole program optimization of virtual calls in cases
// where we know (via !type metadata) that the list of callees is fixed. This
// includes the following:
// - Single implementation devirtualization: if a virtual call has a single
//   possible callee, replace all calls with a direct call to that callee
//   (illustrated in the IR sketch below).
// - Virtual constant propagation: if the virtual function's return type is an
//   integer <=64 bits and all possible callees are readnone, for each class
//   and each list of constant arguments: evaluate the function, store the
//   return value alongside the virtual table, and rewrite each virtual call
//   as a load from the virtual table.
// - Uniform return value optimization: if the conditions for virtual constant
//   propagation hold and each function returns the same constant value,
//   replace each virtual call with that constant.
// - Unique return value optimization for i1 return values: if the conditions
//   for virtual constant propagation hold and a single vtable's function
//   returns 0, or a single vtable's function returns 1, replace each virtual
//   call with a comparison of the vptr against that vtable's address.
//
// This pass is intended to be used during the regular and thin LTO pipelines.
// During regular LTO, the pass determines the best optimization for each
// virtual call and applies the resolutions directly to virtual calls that are
// eligible for virtual call optimization (i.e. calls that use either the
// llvm.assume(llvm.type.test) pattern or the llvm.type.checked.load
// intrinsic). During ThinLTO, the pass operates in two phases:
// - Export phase: this is run during the thin link over a single merged module
//   that contains all vtables with !type metadata that participate in the
//   link. The pass computes a resolution for each virtual call and stores it
//   in the type identifier summary.
// - Import phase: this is run during the thin backends over the individual
//   modules. The pass applies the resolutions previously computed during the
//   export phase to each eligible virtual call.
//
//===----------------------------------------------------------------------===//
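
// For illustration only, a minimal sketch (hypothetical type identifier and
// function names) of what single implementation devirtualization does at the
// IR level. Before:
//
//   %p = call i1 @llvm.type.test(i8* %vtable, metadata !"_ZTS1A")
//   call void @llvm.assume(i1 %p)
//   %fptr = ... load of the function pointer from %vtable ...
//   call void %fptr(%struct.A* %obj)
//
// After, assuming @_ZN1A1fEv is the only possible callee:
//
//   call void @_ZN1A1fEv(%struct.A* %obj)
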
  43. #include "llvm/Transforms/IPO/WholeProgramDevirt.h"
  44. #include "llvm/ADT/ArrayRef.h"
  45. #include "llvm/ADT/DenseMap.h"
  46. #include "llvm/ADT/DenseMapInfo.h"
  47. #include "llvm/ADT/DenseSet.h"
  48. #include "llvm/ADT/iterator_range.h"
  49. #include "llvm/ADT/MapVector.h"
  50. #include "llvm/ADT/SmallVector.h"
  51. #include "llvm/Analysis/AliasAnalysis.h"
  52. #include "llvm/Analysis/BasicAliasAnalysis.h"
  53. #include "llvm/Analysis/TypeMetadataUtils.h"
  54. #include "llvm/IR/CallSite.h"
  55. #include "llvm/IR/Constants.h"
  56. #include "llvm/IR/DataLayout.h"
  57. #include "llvm/IR/DebugInfoMetadata.h"
  58. #include "llvm/IR/DebugLoc.h"
  59. #include "llvm/IR/DerivedTypes.h"
  60. #include "llvm/IR/DiagnosticInfo.h"
  61. #include "llvm/IR/Function.h"
  62. #include "llvm/IR/GlobalAlias.h"
  63. #include "llvm/IR/GlobalVariable.h"
  64. #include "llvm/IR/IRBuilder.h"
  65. #include "llvm/IR/InstrTypes.h"
  66. #include "llvm/IR/Instruction.h"
  67. #include "llvm/IR/Instructions.h"
  68. #include "llvm/IR/Intrinsics.h"
  69. #include "llvm/IR/LLVMContext.h"
  70. #include "llvm/IR/Metadata.h"
  71. #include "llvm/IR/Module.h"
  72. #include "llvm/IR/ModuleSummaryIndexYAML.h"
  73. #include "llvm/Pass.h"
  74. #include "llvm/PassRegistry.h"
  75. #include "llvm/PassSupport.h"
  76. #include "llvm/Support/Casting.h"
  77. #include "llvm/Support/Error.h"
  78. #include "llvm/Support/FileSystem.h"
  79. #include "llvm/Support/MathExtras.h"
  80. #include "llvm/Transforms/IPO.h"
  81. #include "llvm/Transforms/IPO/FunctionAttrs.h"
  82. #include "llvm/Transforms/Utils/Evaluator.h"
  83. #include <algorithm>
  84. #include <cstddef>
  85. #include <map>
  86. #include <set>
  87. #include <string>
  88. using namespace llvm;
  89. using namespace wholeprogramdevirt;
  90. #define DEBUG_TYPE "wholeprogramdevirt"
static cl::opt<PassSummaryAction> ClSummaryAction(
    "wholeprogramdevirt-summary-action",
    cl::desc("What to do with the summary when running this pass"),
    cl::values(clEnumValN(PassSummaryAction::None, "none", "Do nothing"),
               clEnumValN(PassSummaryAction::Import, "import",
                          "Import typeid resolutions from summary and globals"),
               clEnumValN(PassSummaryAction::Export, "export",
                          "Export typeid resolutions to summary and globals")),
    cl::Hidden);

static cl::opt<std::string> ClReadSummary(
    "wholeprogramdevirt-read-summary",
    cl::desc("Read summary from given YAML file before running pass"),
    cl::Hidden);

static cl::opt<std::string> ClWriteSummary(
    "wholeprogramdevirt-write-summary",
    cl::desc("Write summary to given YAML file after running pass"),
    cl::Hidden);
// Find the minimum offset that we may store a value of size Size bits at. If
// IsAfter is set, look for an offset after the object, otherwise look for an
// offset before the object.
uint64_t
wholeprogramdevirt::findLowestOffset(ArrayRef<VirtualCallTarget> Targets,
                                     bool IsAfter, uint64_t Size) {
  // Find a minimum offset taking into account only vtable sizes.
  uint64_t MinByte = 0;
  for (const VirtualCallTarget &Target : Targets) {
    if (IsAfter)
      MinByte = std::max(MinByte, Target.minAfterBytes());
    else
      MinByte = std::max(MinByte, Target.minBeforeBytes());
  }

  // Build a vector of arrays of bytes covering, for each target, a slice of
  // the used region (see AccumBitVector::BytesUsed in
  // llvm/Transforms/IPO/WholeProgramDevirt.h) starting at MinByte.
  // Effectively, this aligns the used regions to start at MinByte.
  //
  // In this example, A, B and C are vtables, # is a byte already allocated for
  // a virtual function pointer, AAAA... (etc.) are the used regions for the
  // vtables and Offset(X) is the value computed for the Offset variable below
  // for X.
  //
  //                    Offset(A)
  //                    |       |
  //                            |MinByte
  // A: ################AAAAAAAA|AAAAAAAA
  // B: ########BBBBBBBBBBBBBBBB|BBBB
  // C: ########################|CCCCCCCCCCCCCCCC
  //            |   Offset(B)   |
  //
  // This code produces the slices of A, B and C that appear after the divider
  // at MinByte.
  std::vector<ArrayRef<uint8_t>> Used;
  for (const VirtualCallTarget &Target : Targets) {
    ArrayRef<uint8_t> VTUsed = IsAfter ? Target.TM->Bits->After.BytesUsed
                                       : Target.TM->Bits->Before.BytesUsed;
    uint64_t Offset = IsAfter ? MinByte - Target.minAfterBytes()
                              : MinByte - Target.minBeforeBytes();

    // Disregard used regions that are smaller than Offset. These are
    // effectively all-free regions that do not need to be checked.
    if (VTUsed.size() > Offset)
      Used.push_back(VTUsed.slice(Offset));
  }

  if (Size == 1) {
    // Find a free bit in each member of Used.
    for (unsigned I = 0;; ++I) {
      uint8_t BitsUsed = 0;
      for (auto &&B : Used)
        if (I < B.size())
          BitsUsed |= B[I];
      if (BitsUsed != 0xff)
        return (MinByte + I) * 8 +
               countTrailingZeros(uint8_t(~BitsUsed), ZB_Undefined);
    }
  } else {
    // Find a free (Size/8) byte region in each member of Used.
    // FIXME: see if alignment helps.
    for (unsigned I = 0;; ++I) {
      for (auto &&B : Used) {
        unsigned Byte = 0;
        while ((I + Byte) < B.size() && Byte < (Size / 8)) {
          if (B[I + Byte])
            goto NextI;
          ++Byte;
        }
      }
      return (MinByte + I) * 8;
    NextI:;
    }
  }
}
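
// A worked example of the bit-allocation loop above (illustrative values, not
// from any real vtable): suppose MinByte == 24 and, after slicing,
// Used == { {0xff, 0x0f}, {0xff} } with Size == 1. At I == 0 the OR of the
// byte columns is 0xff, so no bit is free. At I == 1 only the first array
// still contributes, giving BitsUsed == 0x0f, so the function returns
// (24 + 1) * 8 + countTrailingZeros(0xf0) == 200 + 4 == 204, i.e. bit 4 of
// the byte at offset 25 is the lowest free position.
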
void wholeprogramdevirt::setBeforeReturnValues(
    MutableArrayRef<VirtualCallTarget> Targets, uint64_t AllocBefore,
    unsigned BitWidth, int64_t &OffsetByte, uint64_t &OffsetBit) {
  if (BitWidth == 1)
    OffsetByte = -(AllocBefore / 8 + 1);
  else
    OffsetByte = -((AllocBefore + 7) / 8 + (BitWidth + 7) / 8);
  OffsetBit = AllocBefore % 8;

  for (VirtualCallTarget &Target : Targets) {
    if (BitWidth == 1)
      Target.setBeforeBit(AllocBefore);
    else
      Target.setBeforeBytes(AllocBefore, (BitWidth + 7) / 8);
  }
}

void wholeprogramdevirt::setAfterReturnValues(
    MutableArrayRef<VirtualCallTarget> Targets, uint64_t AllocAfter,
    unsigned BitWidth, int64_t &OffsetByte, uint64_t &OffsetBit) {
  if (BitWidth == 1)
    OffsetByte = AllocAfter / 8;
  else
    OffsetByte = (AllocAfter + 7) / 8;
  OffsetBit = AllocAfter % 8;

  for (VirtualCallTarget &Target : Targets) {
    if (BitWidth == 1)
      Target.setAfterBit(AllocAfter);
    else
      Target.setAfterBytes(AllocAfter, (BitWidth + 7) / 8);
  }
}
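
// A worked example of the offset arithmetic above (illustrative numbers):
// for a 1-bit return value allocated before the vtable at bit position
// AllocBefore == 19, OffsetByte == -(19 / 8 + 1) == -3 and
// OffsetBit == 19 % 8 == 3, so applyVirtualConstProp below will test bit 3
// of the byte three bytes before the address point. For a 32-bit value with
// AllocAfter == 8, OffsetByte == (8 + 7) / 8 == 1; positive byte offsets
// select storage after the address point, negative ones before it.
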
VirtualCallTarget::VirtualCallTarget(Function *Fn, const TypeMemberInfo *TM)
    : Fn(Fn), TM(TM),
      IsBigEndian(Fn->getParent()->getDataLayout().isBigEndian()),
      WasDevirt(false) {}

namespace {

// A slot in a set of virtual tables. The TypeID identifies the set of virtual
// tables, and the ByteOffset is the offset in bytes from the address point to
// the virtual function pointer.
struct VTableSlot {
  Metadata *TypeID;
  uint64_t ByteOffset;
};

} // end anonymous namespace
namespace llvm {

template <> struct DenseMapInfo<VTableSlot> {
  static VTableSlot getEmptyKey() {
    return {DenseMapInfo<Metadata *>::getEmptyKey(),
            DenseMapInfo<uint64_t>::getEmptyKey()};
  }
  static VTableSlot getTombstoneKey() {
    return {DenseMapInfo<Metadata *>::getTombstoneKey(),
            DenseMapInfo<uint64_t>::getTombstoneKey()};
  }
  static unsigned getHashValue(const VTableSlot &I) {
    return DenseMapInfo<Metadata *>::getHashValue(I.TypeID) ^
           DenseMapInfo<uint64_t>::getHashValue(I.ByteOffset);
  }
  static bool isEqual(const VTableSlot &LHS, const VTableSlot &RHS) {
    return LHS.TypeID == RHS.TypeID && LHS.ByteOffset == RHS.ByteOffset;
  }
};

} // end namespace llvm
namespace {

// A virtual call site. VTable is the loaded virtual table pointer, and CS is
// the indirect virtual call.
struct VirtualCallSite {
  Value *VTable;
  CallSite CS;

  // If non-null, this field points to the associated unsafe use count stored
  // in the DevirtModule::NumUnsafeUsesForTypeTest map below. See the
  // description of that field for details.
  unsigned *NumUnsafeUses;

  void emitRemark(const Twine &OptName, const Twine &TargetName) {
    Function *F = CS.getCaller();
    emitOptimizationRemark(
        F->getContext(), DEBUG_TYPE, *F, CS.getInstruction()->getDebugLoc(),
        OptName + ": devirtualized a call to " + TargetName);
  }

  void replaceAndErase(const Twine &OptName, const Twine &TargetName,
                       bool RemarksEnabled, Value *New) {
    if (RemarksEnabled)
      emitRemark(OptName, TargetName);
    CS->replaceAllUsesWith(New);
    if (auto II = dyn_cast<InvokeInst>(CS.getInstruction())) {
      BranchInst::Create(II->getNormalDest(), CS.getInstruction());
      II->getUnwindDest()->removePredecessor(II->getParent());
    }
    CS->eraseFromParent();
    // This use is no longer unsafe.
    if (NumUnsafeUses)
      --*NumUnsafeUses;
  }
};
// Call site information collected for a specific VTableSlot and possibly a
// list of constant integer arguments. The grouping by arguments is handled by
// the VTableSlotInfo class.
struct CallSiteInfo {
  /// The set of call sites for this slot. Used during regular LTO and the
  /// import phase of ThinLTO (as well as the export phase of ThinLTO for any
  /// call sites that appear in the merged module itself); in each of these
  /// cases we are directly operating on the call sites at the IR level.
  std::vector<VirtualCallSite> CallSites;

  // These fields are used during the export phase of ThinLTO and reflect
  // information collected from function summaries.

  /// Whether any function summary contains an llvm.assume(llvm.type.test) for
  /// this slot.
  bool SummaryHasTypeTestAssumeUsers;

  /// CFI-specific: a vector containing the list of function summaries that use
  /// the llvm.type.checked.load intrinsic and therefore will require
  /// resolutions for llvm.type.test in order to implement CFI checks if
  /// devirtualization was unsuccessful. If devirtualization was successful,
  /// the pass will clear this vector by calling markDevirt(). If at the end of
  /// the pass the vector is non-empty, we will need to add a use of
  /// llvm.type.test to each of the function summaries in the vector.
  std::vector<FunctionSummary *> SummaryTypeCheckedLoadUsers;

  bool isExported() const {
    return SummaryHasTypeTestAssumeUsers ||
           !SummaryTypeCheckedLoadUsers.empty();
  }

  /// As explained in the comment for SummaryTypeCheckedLoadUsers.
  void markDevirt() { SummaryTypeCheckedLoadUsers.clear(); }
};
// Call site information collected for a specific VTableSlot.
struct VTableSlotInfo {
  // The set of call sites which do not have all constant integer arguments
  // (excluding "this").
  CallSiteInfo CSInfo;

  // The set of call sites with all constant integer arguments (excluding
  // "this"), grouped by argument list.
  std::map<std::vector<uint64_t>, CallSiteInfo> ConstCSInfo;

  void addCallSite(Value *VTable, CallSite CS, unsigned *NumUnsafeUses);

private:
  CallSiteInfo &findCallSiteInfo(CallSite CS);
};

CallSiteInfo &VTableSlotInfo::findCallSiteInfo(CallSite CS) {
  std::vector<uint64_t> Args;
  auto *CI = dyn_cast<IntegerType>(CS.getType());
  if (!CI || CI->getBitWidth() > 64 || CS.arg_empty())
    return CSInfo;
  for (auto &&Arg : make_range(CS.arg_begin() + 1, CS.arg_end())) {
    auto *CI = dyn_cast<ConstantInt>(Arg);
    if (!CI || CI->getBitWidth() > 64)
      return CSInfo;
    Args.push_back(CI->getZExtValue());
  }
  return ConstCSInfo[Args];
}

void VTableSlotInfo::addCallSite(Value *VTable, CallSite CS,
                                 unsigned *NumUnsafeUses) {
  findCallSiteInfo(CS).CallSites.push_back({VTable, CS, NumUnsafeUses});
}
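
// For illustration (hypothetical IR): a call such as
//
//   %r = call i32 %fptr(%struct.A* %obj, i32 5, i32 7)
//
// has all-constant integer arguments after "this", so findCallSiteInfo files
// it under ConstCSInfo[{5, 7}], whereas a call whose second argument is a
// non-constant value falls back to the generic CSInfo bucket above.
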
struct DevirtModule {
  Module &M;
  function_ref<AAResults &(Function &)> AARGetter;

  ModuleSummaryIndex *ExportSummary;
  const ModuleSummaryIndex *ImportSummary;

  IntegerType *Int8Ty;
  PointerType *Int8PtrTy;
  IntegerType *Int32Ty;
  IntegerType *Int64Ty;
  IntegerType *IntPtrTy;

  bool RemarksEnabled;

  MapVector<VTableSlot, VTableSlotInfo> CallSlots;

  // This map keeps track of the number of "unsafe" uses of a loaded function
  // pointer. The key is the associated llvm.type.test intrinsic call generated
  // by this pass. An unsafe use is one that calls the loaded function pointer
  // directly. Every time we eliminate an unsafe use (for example, by
  // devirtualizing it or by applying virtual constant propagation), we
  // decrement the value stored in this map. If a value reaches zero, we can
  // eliminate the type check by RAUWing the associated llvm.type.test call
  // with true.
  std::map<CallInst *, unsigned> NumUnsafeUsesForTypeTest;

  DevirtModule(Module &M, function_ref<AAResults &(Function &)> AARGetter,
               ModuleSummaryIndex *ExportSummary,
               const ModuleSummaryIndex *ImportSummary)
      : M(M), AARGetter(AARGetter), ExportSummary(ExportSummary),
        ImportSummary(ImportSummary), Int8Ty(Type::getInt8Ty(M.getContext())),
        Int8PtrTy(Type::getInt8PtrTy(M.getContext())),
        Int32Ty(Type::getInt32Ty(M.getContext())),
        Int64Ty(Type::getInt64Ty(M.getContext())),
        IntPtrTy(M.getDataLayout().getIntPtrType(M.getContext(), 0)),
        RemarksEnabled(areRemarksEnabled()) {
    assert(!(ExportSummary && ImportSummary));
  }

  bool areRemarksEnabled();

  void scanTypeTestUsers(Function *TypeTestFunc, Function *AssumeFunc);
  void scanTypeCheckedLoadUsers(Function *TypeCheckedLoadFunc);

  void buildTypeIdentifierMap(
      std::vector<VTableBits> &Bits,
      DenseMap<Metadata *, std::set<TypeMemberInfo>> &TypeIdMap);
  Constant *getPointerAtOffset(Constant *I, uint64_t Offset);
  bool
  tryFindVirtualCallTargets(std::vector<VirtualCallTarget> &TargetsForSlot,
                            const std::set<TypeMemberInfo> &TypeMemberInfos,
                            uint64_t ByteOffset);

  void applySingleImplDevirt(VTableSlotInfo &SlotInfo, Constant *TheFn,
                             bool &IsExported);
  bool trySingleImplDevirt(MutableArrayRef<VirtualCallTarget> TargetsForSlot,
                           VTableSlotInfo &SlotInfo,
                           WholeProgramDevirtResolution *Res);

  bool tryEvaluateFunctionsWithArgs(
      MutableArrayRef<VirtualCallTarget> TargetsForSlot,
      ArrayRef<uint64_t> Args);

  void applyUniformRetValOpt(CallSiteInfo &CSInfo, StringRef FnName,
                             uint64_t TheRetVal);
  bool tryUniformRetValOpt(MutableArrayRef<VirtualCallTarget> TargetsForSlot,
                           CallSiteInfo &CSInfo,
                           WholeProgramDevirtResolution::ByArg *Res);

  // Returns the global symbol name that is used to export information about
  // the given vtable slot and list of arguments.
  std::string getGlobalName(VTableSlot Slot, ArrayRef<uint64_t> Args,
                            StringRef Name);

  // This function is called during the export phase to create a symbol
  // definition containing information about the given vtable slot and list of
  // arguments.
  void exportGlobal(VTableSlot Slot, ArrayRef<uint64_t> Args, StringRef Name,
                    Constant *C);

  // This function is called during the import phase to create a reference to
  // the symbol definition created during the export phase.
  Constant *importGlobal(VTableSlot Slot, ArrayRef<uint64_t> Args,
                         StringRef Name, unsigned AbsWidth = 0);

  void applyUniqueRetValOpt(CallSiteInfo &CSInfo, StringRef FnName, bool IsOne,
                            Constant *UniqueMemberAddr);
  bool tryUniqueRetValOpt(unsigned BitWidth,
                          MutableArrayRef<VirtualCallTarget> TargetsForSlot,
                          CallSiteInfo &CSInfo,
                          WholeProgramDevirtResolution::ByArg *Res,
                          VTableSlot Slot, ArrayRef<uint64_t> Args);

  void applyVirtualConstProp(CallSiteInfo &CSInfo, StringRef FnName,
                             Constant *Byte, Constant *Bit);
  bool tryVirtualConstProp(MutableArrayRef<VirtualCallTarget> TargetsForSlot,
                           VTableSlotInfo &SlotInfo,
                           WholeProgramDevirtResolution *Res, VTableSlot Slot);

  void rebuildGlobal(VTableBits &B);

  // Apply the summary resolution for Slot to all virtual calls in SlotInfo.
  void importResolution(VTableSlot Slot, VTableSlotInfo &SlotInfo);

  // If we were able to eliminate all unsafe uses for a type checked load,
  // eliminate the associated type tests by replacing them with true.
  void removeRedundantTypeTests();

  bool run();

  // Lower the module using the action and summary passed as command line
  // arguments. For testing purposes only.
  static bool runForTesting(Module &M,
                            function_ref<AAResults &(Function &)> AARGetter);
};
struct WholeProgramDevirt : public ModulePass {
  static char ID;

  bool UseCommandLine = false;

  ModuleSummaryIndex *ExportSummary;
  const ModuleSummaryIndex *ImportSummary;

  WholeProgramDevirt() : ModulePass(ID), UseCommandLine(true) {
    initializeWholeProgramDevirtPass(*PassRegistry::getPassRegistry());
  }

  WholeProgramDevirt(ModuleSummaryIndex *ExportSummary,
                     const ModuleSummaryIndex *ImportSummary)
      : ModulePass(ID), ExportSummary(ExportSummary),
        ImportSummary(ImportSummary) {
    initializeWholeProgramDevirtPass(*PassRegistry::getPassRegistry());
  }

  bool runOnModule(Module &M) override {
    if (skipModule(M))
      return false;
    if (UseCommandLine)
      return DevirtModule::runForTesting(M, LegacyAARGetter(*this));
    return DevirtModule(M, LegacyAARGetter(*this), ExportSummary,
                        ImportSummary)
        .run();
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<AssumptionCacheTracker>();
    AU.addRequired<TargetLibraryInfoWrapperPass>();
  }
};

} // end anonymous namespace
INITIALIZE_PASS_BEGIN(WholeProgramDevirt, "wholeprogramdevirt",
                      "Whole program devirtualization", false, false)
INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
INITIALIZE_PASS_END(WholeProgramDevirt, "wholeprogramdevirt",
                    "Whole program devirtualization", false, false)

char WholeProgramDevirt::ID = 0;

ModulePass *
llvm::createWholeProgramDevirtPass(ModuleSummaryIndex *ExportSummary,
                                   const ModuleSummaryIndex *ImportSummary) {
  return new WholeProgramDevirt(ExportSummary, ImportSummary);
}

PreservedAnalyses WholeProgramDevirtPass::run(Module &M,
                                              ModuleAnalysisManager &AM) {
  auto &FAM = AM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager();
  auto AARGetter = [&](Function &F) -> AAResults & {
    return FAM.getResult<AAManager>(F);
  };
  if (!DevirtModule(M, AARGetter, nullptr, nullptr).run())
    return PreservedAnalyses::all();
  return PreservedAnalyses::none();
}
bool DevirtModule::runForTesting(
    Module &M, function_ref<AAResults &(Function &)> AARGetter) {
  ModuleSummaryIndex Summary;

  // Handle the command-line summary arguments. This code is for testing
  // purposes only, so we handle errors directly.
  if (!ClReadSummary.empty()) {
    ExitOnError ExitOnErr("-wholeprogramdevirt-read-summary: " + ClReadSummary +
                          ": ");
    auto ReadSummaryFile =
        ExitOnErr(errorOrToExpected(MemoryBuffer::getFile(ClReadSummary)));

    yaml::Input In(ReadSummaryFile->getBuffer());
    In >> Summary;
    ExitOnErr(errorCodeToError(In.error()));
  }

  bool Changed =
      DevirtModule(
          M, AARGetter,
          ClSummaryAction == PassSummaryAction::Export ? &Summary : nullptr,
          ClSummaryAction == PassSummaryAction::Import ? &Summary : nullptr)
          .run();

  if (!ClWriteSummary.empty()) {
    ExitOnError ExitOnErr(
        "-wholeprogramdevirt-write-summary: " + ClWriteSummary + ": ");
    std::error_code EC;
    raw_fd_ostream OS(ClWriteSummary, EC, sys::fs::F_Text);
    ExitOnErr(errorCodeToError(EC));

    yaml::Output Out(OS);
    Out << Summary;
  }

  return Changed;
}
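
// For illustration, the testing flags above allow an invocation such as
// (file names hypothetical):
//
//   opt -wholeprogramdevirt -wholeprogramdevirt-summary-action=export \
//       -wholeprogramdevirt-write-summary=summary.yaml in.ll -S -o out.ll
//
// which exercises the export phase against a YAML summary instead of a real
// combined summary index.
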
void DevirtModule::buildTypeIdentifierMap(
    std::vector<VTableBits> &Bits,
    DenseMap<Metadata *, std::set<TypeMemberInfo>> &TypeIdMap) {
  DenseMap<GlobalVariable *, VTableBits *> GVToBits;
  Bits.reserve(M.getGlobalList().size());
  SmallVector<MDNode *, 2> Types;
  for (GlobalVariable &GV : M.globals()) {
    Types.clear();
    GV.getMetadata(LLVMContext::MD_type, Types);
    if (Types.empty())
      continue;

    VTableBits *&BitsPtr = GVToBits[&GV];
    if (!BitsPtr) {
      Bits.emplace_back();
      Bits.back().GV = &GV;
      Bits.back().ObjectSize =
          M.getDataLayout().getTypeAllocSize(GV.getInitializer()->getType());
      BitsPtr = &Bits.back();
    }

    for (MDNode *Type : Types) {
      auto TypeID = Type->getOperand(1).get();

      uint64_t Offset =
          cast<ConstantInt>(
              cast<ConstantAsMetadata>(Type->getOperand(0))->getValue())
              ->getZExtValue();

      TypeIdMap[TypeID].insert({BitsPtr, Offset});
    }
  }
}
Constant *DevirtModule::getPointerAtOffset(Constant *I, uint64_t Offset) {
  if (I->getType()->isPointerTy()) {
    if (Offset == 0)
      return I;
    return nullptr;
  }

  const DataLayout &DL = M.getDataLayout();

  if (auto *C = dyn_cast<ConstantStruct>(I)) {
    const StructLayout *SL = DL.getStructLayout(C->getType());
    if (Offset >= SL->getSizeInBytes())
      return nullptr;

    unsigned Op = SL->getElementContainingOffset(Offset);
    return getPointerAtOffset(cast<Constant>(I->getOperand(Op)),
                              Offset - SL->getElementOffset(Op));
  }
  if (auto *C = dyn_cast<ConstantArray>(I)) {
    ArrayType *VTableTy = C->getType();
    uint64_t ElemSize = DL.getTypeAllocSize(VTableTy->getElementType());

    unsigned Op = Offset / ElemSize;
    if (Op >= C->getNumOperands())
      return nullptr;

    return getPointerAtOffset(cast<Constant>(I->getOperand(Op)),
                              Offset % ElemSize);
  }
  return nullptr;
}
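
// For illustration (hypothetical vtable, 64-bit pointers): given an
// initializer of the form
//
//   { [3 x i8*] [i8* null, i8* @rtti,
//                i8* bitcast (void (%struct.A*)* @_ZN1A1fEv to i8*)] }
//
// a query at Offset == 16 descends into the struct, then into the array
// element at index 16 / 8 == 2, and returns the function pointer constant.
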
bool DevirtModule::tryFindVirtualCallTargets(
    std::vector<VirtualCallTarget> &TargetsForSlot,
    const std::set<TypeMemberInfo> &TypeMemberInfos, uint64_t ByteOffset) {
  for (const TypeMemberInfo &TM : TypeMemberInfos) {
    if (!TM.Bits->GV->isConstant())
      return false;

    Constant *Ptr = getPointerAtOffset(TM.Bits->GV->getInitializer(),
                                       TM.Offset + ByteOffset);
    if (!Ptr)
      return false;

    auto Fn = dyn_cast<Function>(Ptr->stripPointerCasts());
    if (!Fn)
      return false;

    // We can disregard __cxa_pure_virtual as a possible call target, as
    // calls to pure virtuals are UB.
    if (Fn->getName() == "__cxa_pure_virtual")
      continue;

    TargetsForSlot.push_back({Fn, &TM});
  }

  // Give up if we couldn't find any targets.
  return !TargetsForSlot.empty();
}
void DevirtModule::applySingleImplDevirt(VTableSlotInfo &SlotInfo,
                                         Constant *TheFn, bool &IsExported) {
  auto Apply = [&](CallSiteInfo &CSInfo) {
    for (auto &&VCallSite : CSInfo.CallSites) {
      if (RemarksEnabled)
        VCallSite.emitRemark("single-impl", TheFn->getName());
      VCallSite.CS.setCalledFunction(ConstantExpr::getBitCast(
          TheFn, VCallSite.CS.getCalledValue()->getType()));
      // This use is no longer unsafe.
      if (VCallSite.NumUnsafeUses)
        --*VCallSite.NumUnsafeUses;
    }
    if (CSInfo.isExported()) {
      IsExported = true;
      CSInfo.markDevirt();
    }
  };
  Apply(SlotInfo.CSInfo);
  for (auto &P : SlotInfo.ConstCSInfo)
    Apply(P.second);
}
bool DevirtModule::trySingleImplDevirt(
    MutableArrayRef<VirtualCallTarget> TargetsForSlot,
    VTableSlotInfo &SlotInfo, WholeProgramDevirtResolution *Res) {
  // See if the program contains a single implementation of this virtual
  // function.
  Function *TheFn = TargetsForSlot[0].Fn;
  for (auto &&Target : TargetsForSlot)
    if (TheFn != Target.Fn)
      return false;

  // If so, update each call site to call that implementation directly.
  if (RemarksEnabled)
    TargetsForSlot[0].WasDevirt = true;

  bool IsExported = false;
  applySingleImplDevirt(SlotInfo, TheFn, IsExported);
  if (!IsExported)
    return false;

  // If the only implementation has local linkage, we must promote to external
  // to make it visible to thin LTO objects. We can only get here during the
  // ThinLTO export phase.
  if (TheFn->hasLocalLinkage()) {
    TheFn->setLinkage(GlobalValue::ExternalLinkage);
    TheFn->setVisibility(GlobalValue::HiddenVisibility);
    TheFn->setName(TheFn->getName() + "$merged");
  }

  Res->TheKind = WholeProgramDevirtResolution::SingleImpl;
  Res->SingleImplName = TheFn->getName();

  return true;
}
bool DevirtModule::tryEvaluateFunctionsWithArgs(
    MutableArrayRef<VirtualCallTarget> TargetsForSlot,
    ArrayRef<uint64_t> Args) {
  // Evaluate each function and store the result in each target's RetVal
  // field.
  for (VirtualCallTarget &Target : TargetsForSlot) {
    if (Target.Fn->arg_size() != Args.size() + 1)
      return false;

    Evaluator Eval(M.getDataLayout(), nullptr);
    SmallVector<Constant *, 2> EvalArgs;
    EvalArgs.push_back(
        Constant::getNullValue(Target.Fn->getFunctionType()->getParamType(0)));
    for (unsigned I = 0; I != Args.size(); ++I) {
      auto *ArgTy = dyn_cast<IntegerType>(
          Target.Fn->getFunctionType()->getParamType(I + 1));
      if (!ArgTy)
        return false;
      EvalArgs.push_back(ConstantInt::get(ArgTy, Args[I]));
    }

    Constant *RetVal;
    if (!Eval.EvaluateFunction(Target.Fn, RetVal, EvalArgs) ||
        !isa<ConstantInt>(RetVal))
      return false;
    Target.RetVal = cast<ConstantInt>(RetVal)->getZExtValue();
  }
  return true;
}
void DevirtModule::applyUniformRetValOpt(CallSiteInfo &CSInfo, StringRef FnName,
                                         uint64_t TheRetVal) {
  for (auto Call : CSInfo.CallSites)
    Call.replaceAndErase(
        "uniform-ret-val", FnName, RemarksEnabled,
        ConstantInt::get(cast<IntegerType>(Call.CS.getType()), TheRetVal));
  CSInfo.markDevirt();
}

bool DevirtModule::tryUniformRetValOpt(
    MutableArrayRef<VirtualCallTarget> TargetsForSlot, CallSiteInfo &CSInfo,
    WholeProgramDevirtResolution::ByArg *Res) {
  // Uniform return value optimization. If all functions return the same
  // constant, replace all calls with that constant.
  uint64_t TheRetVal = TargetsForSlot[0].RetVal;
  for (const VirtualCallTarget &Target : TargetsForSlot)
    if (Target.RetVal != TheRetVal)
      return false;

  if (CSInfo.isExported()) {
    Res->TheKind = WholeProgramDevirtResolution::ByArg::UniformRetVal;
    Res->Info = TheRetVal;
  }

  applyUniformRetValOpt(CSInfo, TargetsForSlot[0].Fn->getName(), TheRetVal);
  if (RemarksEnabled)
    for (auto &&Target : TargetsForSlot)
      Target.WasDevirt = true;
  return true;
}
std::string DevirtModule::getGlobalName(VTableSlot Slot,
                                        ArrayRef<uint64_t> Args,
                                        StringRef Name) {
  std::string FullName = "__typeid_";
  raw_string_ostream OS(FullName);
  OS << cast<MDString>(Slot.TypeID)->getString() << '_' << Slot.ByteOffset;
  for (uint64_t Arg : Args)
    OS << '_' << Arg;
  OS << '_' << Name;
  return OS.str();
}

void DevirtModule::exportGlobal(VTableSlot Slot, ArrayRef<uint64_t> Args,
                                StringRef Name, Constant *C) {
  GlobalAlias *GA = GlobalAlias::create(Int8Ty, 0, GlobalValue::ExternalLinkage,
                                        getGlobalName(Slot, Args, Name), C, &M);
  GA->setVisibility(GlobalValue::HiddenVisibility);
}
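
// For illustration: with a (hypothetical) type identifier "_ZTS1A", byte
// offset 0, constant argument list {5} and Name == "byte", getGlobalName
// produces the symbol "__typeid__ZTS1A_0_5_byte". The export phase defines
// that symbol via exportGlobal above, and the import phase resolves it by
// name via importGlobal below.
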
Constant *DevirtModule::importGlobal(VTableSlot Slot, ArrayRef<uint64_t> Args,
                                     StringRef Name, unsigned AbsWidth) {
  Constant *C = M.getOrInsertGlobal(getGlobalName(Slot, Args, Name), Int8Ty);
  auto *GV = dyn_cast<GlobalVariable>(C);
  // We only need to set metadata if the global is newly created, in which
  // case it would not have hidden visibility.
  if (!GV || GV->getVisibility() == GlobalValue::HiddenVisibility)
    return C;

  GV->setVisibility(GlobalValue::HiddenVisibility);
  auto SetAbsRange = [&](uint64_t Min, uint64_t Max) {
    auto *MinC = ConstantAsMetadata::get(ConstantInt::get(IntPtrTy, Min));
    auto *MaxC = ConstantAsMetadata::get(ConstantInt::get(IntPtrTy, Max));
    GV->setMetadata(LLVMContext::MD_absolute_symbol,
                    MDNode::get(M.getContext(), {MinC, MaxC}));
  };
  if (AbsWidth == IntPtrTy->getBitWidth())
    SetAbsRange(~0ull, ~0ull); // Full set.
  else if (AbsWidth)
    SetAbsRange(0, 1ull << AbsWidth);
  return GV;
}
void DevirtModule::applyUniqueRetValOpt(CallSiteInfo &CSInfo, StringRef FnName,
                                        bool IsOne,
                                        Constant *UniqueMemberAddr) {
  for (auto &&Call : CSInfo.CallSites) {
    IRBuilder<> B(Call.CS.getInstruction());
    Value *Cmp = B.CreateICmp(IsOne ? ICmpInst::ICMP_EQ : ICmpInst::ICMP_NE,
                              Call.VTable, UniqueMemberAddr);
    Cmp = B.CreateZExt(Cmp, Call.CS->getType());
    Call.replaceAndErase("unique-ret-val", FnName, RemarksEnabled, Cmp);
  }
  CSInfo.markDevirt();
}
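
// For illustration (hypothetical IR): if only @_ZTV1B's implementation of an
// i1-returning virtual function returns 1, each call
//
//   %r = call i1 %fptr(%struct.A* %obj)
//
// is replaced by a comparison of the vtable pointer against B's member
// address computed in tryUniqueRetValOpt below, roughly:
//
//   %r = icmp eq i8* %vtable,
//        getelementptr (i8, i8* bitcast ({...}* @_ZTV1B to i8*), i64 16)
//
// (CreateZExt is a no-op here because both the comparison and the call have
// type i1).
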
bool DevirtModule::tryUniqueRetValOpt(
    unsigned BitWidth, MutableArrayRef<VirtualCallTarget> TargetsForSlot,
    CallSiteInfo &CSInfo, WholeProgramDevirtResolution::ByArg *Res,
    VTableSlot Slot, ArrayRef<uint64_t> Args) {
  // IsOne controls whether we look for a 0 or a 1.
  auto tryUniqueRetValOptFor = [&](bool IsOne) {
    const TypeMemberInfo *UniqueMember = nullptr;
    for (const VirtualCallTarget &Target : TargetsForSlot) {
      if (Target.RetVal == (IsOne ? 1 : 0)) {
        if (UniqueMember)
          return false;
        UniqueMember = Target.TM;
      }
    }

    // We should have found a unique member or bailed out by now. We already
    // checked for a uniform return value in tryUniformRetValOpt.
    assert(UniqueMember);

    Constant *UniqueMemberAddr =
        ConstantExpr::getBitCast(UniqueMember->Bits->GV, Int8PtrTy);
    UniqueMemberAddr = ConstantExpr::getGetElementPtr(
        Int8Ty, UniqueMemberAddr,
        ConstantInt::get(Int64Ty, UniqueMember->Offset));

    if (CSInfo.isExported()) {
      Res->TheKind = WholeProgramDevirtResolution::ByArg::UniqueRetVal;
      Res->Info = IsOne;

      exportGlobal(Slot, Args, "unique_member", UniqueMemberAddr);
    }

    // Replace each call with the comparison.
    applyUniqueRetValOpt(CSInfo, TargetsForSlot[0].Fn->getName(), IsOne,
                         UniqueMemberAddr);

    // Update devirtualization statistics for targets.
    if (RemarksEnabled)
      for (auto &&Target : TargetsForSlot)
        Target.WasDevirt = true;

    return true;
  };

  if (BitWidth == 1) {
    if (tryUniqueRetValOptFor(true))
      return true;
    if (tryUniqueRetValOptFor(false))
      return true;
  }
  return false;
}
void DevirtModule::applyVirtualConstProp(CallSiteInfo &CSInfo, StringRef FnName,
                                         Constant *Byte, Constant *Bit) {
  for (auto Call : CSInfo.CallSites) {
    auto *RetType = cast<IntegerType>(Call.CS.getType());
    IRBuilder<> B(Call.CS.getInstruction());
    Value *Addr = B.CreateGEP(Int8Ty, Call.VTable, Byte);
    if (RetType->getBitWidth() == 1) {
      Value *Bits = B.CreateLoad(Addr);
      Value *BitsAndBit = B.CreateAnd(Bits, Bit);
      auto IsBitSet = B.CreateICmpNE(BitsAndBit, ConstantInt::get(Int8Ty, 0));
      Call.replaceAndErase("virtual-const-prop-1-bit", FnName, RemarksEnabled,
                           IsBitSet);
    } else {
      Value *ValAddr = B.CreateBitCast(Addr, RetType->getPointerTo());
      Value *Val = B.CreateLoad(RetType, ValAddr);
      Call.replaceAndErase("virtual-const-prop", FnName, RemarksEnabled, Val);
    }
  }
  CSInfo.markDevirt();
}
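
// For illustration (hypothetical IR, multi-bit case): a call
//
//   %r = call i32 %fptr(%struct.A* %obj, i32 5)
//
// is rewritten into a load of the precomputed return value stored alongside
// the vtable, at the byte offset chosen in tryVirtualConstProp below:
//
//   %addr = getelementptr i8, i8* %vtable, i32 -5
//   %valp = bitcast i8* %addr to i32*
//   %r = load i32, i32* %valp
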
bool DevirtModule::tryVirtualConstProp(
    MutableArrayRef<VirtualCallTarget> TargetsForSlot, VTableSlotInfo &SlotInfo,
    WholeProgramDevirtResolution *Res, VTableSlot Slot) {
  // This only works if the function returns an integer.
  auto RetType = dyn_cast<IntegerType>(TargetsForSlot[0].Fn->getReturnType());
  if (!RetType)
    return false;
  unsigned BitWidth = RetType->getBitWidth();
  if (BitWidth > 64)
    return false;

  // Make sure that each function is defined, does not access memory, takes at
  // least one argument, does not use its first argument (which we assume is
  // 'this'), and has the same return type.
  //
  // Note that we test whether this copy of the function is readnone, rather
  // than testing function attributes, which must hold for any copy of the
  // function, even a less optimized version substituted at link time. This is
  // sound because the virtual constant propagation optimizations effectively
  // inline all implementations of the virtual function into each call site,
  // rather than using function attributes to perform local optimization.
  for (VirtualCallTarget &Target : TargetsForSlot) {
    if (Target.Fn->isDeclaration() ||
        computeFunctionBodyMemoryAccess(*Target.Fn, AARGetter(*Target.Fn)) !=
            MAK_ReadNone ||
        Target.Fn->arg_empty() || !Target.Fn->arg_begin()->use_empty() ||
        Target.Fn->getReturnType() != RetType)
      return false;
  }

  for (auto &&CSByConstantArg : SlotInfo.ConstCSInfo) {
    if (!tryEvaluateFunctionsWithArgs(TargetsForSlot, CSByConstantArg.first))
      continue;

    WholeProgramDevirtResolution::ByArg *ResByArg = nullptr;
    if (Res)
      ResByArg = &Res->ResByArg[CSByConstantArg.first];

    if (tryUniformRetValOpt(TargetsForSlot, CSByConstantArg.second, ResByArg))
      continue;

    if (tryUniqueRetValOpt(BitWidth, TargetsForSlot, CSByConstantArg.second,
                           ResByArg, Slot, CSByConstantArg.first))
      continue;

    // Find an allocation offset in bits in all vtables associated with the
    // type.
    uint64_t AllocBefore =
        findLowestOffset(TargetsForSlot, /*IsAfter=*/false, BitWidth);
    uint64_t AllocAfter =
        findLowestOffset(TargetsForSlot, /*IsAfter=*/true, BitWidth);

    // Calculate the total amount of padding needed to store a value at both
    // ends of the object.
    uint64_t TotalPaddingBefore = 0, TotalPaddingAfter = 0;
    for (auto &&Target : TargetsForSlot) {
      TotalPaddingBefore += std::max<int64_t>(
          (AllocBefore + 7) / 8 - Target.allocatedBeforeBytes() - 1, 0);
      TotalPaddingAfter += std::max<int64_t>(
          (AllocAfter + 7) / 8 - Target.allocatedAfterBytes() - 1, 0);
    }

    // If the amount of padding is too large, give up.
    // FIXME: do something smarter here.
    if (std::min(TotalPaddingBefore, TotalPaddingAfter) > 128)
      continue;

    // Calculate the offset to the value as a (possibly negative) byte offset
    // and (if applicable) a bit offset, and store the values in the targets.
    int64_t OffsetByte;
    uint64_t OffsetBit;
    if (TotalPaddingBefore <= TotalPaddingAfter)
      setBeforeReturnValues(TargetsForSlot, AllocBefore, BitWidth, OffsetByte,
                            OffsetBit);
    else
      setAfterReturnValues(TargetsForSlot, AllocAfter, BitWidth, OffsetByte,
                           OffsetBit);

    if (RemarksEnabled)
      for (auto &&Target : TargetsForSlot)
        Target.WasDevirt = true;

    Constant *ByteConst = ConstantInt::get(Int32Ty, OffsetByte);
    Constant *BitConst = ConstantInt::get(Int8Ty, 1ULL << OffsetBit);

    if (CSByConstantArg.second.isExported()) {
      ResByArg->TheKind = WholeProgramDevirtResolution::ByArg::VirtualConstProp;
      exportGlobal(Slot, CSByConstantArg.first, "byte",
                   ConstantExpr::getIntToPtr(ByteConst, Int8PtrTy));
      exportGlobal(Slot, CSByConstantArg.first, "bit",
                   ConstantExpr::getIntToPtr(BitConst, Int8PtrTy));
    }

    // Rewrite each call to a load from OffsetByte/OffsetBit.
    applyVirtualConstProp(CSByConstantArg.second,
                          TargetsForSlot[0].Fn->getName(), ByteConst, BitConst);
  }
  return true;
}
void DevirtModule::rebuildGlobal(VTableBits &B) {
  if (B.Before.Bytes.empty() && B.After.Bytes.empty())
    return;

  // Align each byte array to pointer width.
  unsigned PointerSize = M.getDataLayout().getPointerSize();
  B.Before.Bytes.resize(alignTo(B.Before.Bytes.size(), PointerSize));
  B.After.Bytes.resize(alignTo(B.After.Bytes.size(), PointerSize));

  // Before was stored in reverse order; flip it now.
  for (size_t I = 0, Size = B.Before.Bytes.size(); I != Size / 2; ++I)
    std::swap(B.Before.Bytes[I], B.Before.Bytes[Size - 1 - I]);

  // Build an anonymous global containing the before bytes, followed by the
  // original initializer, followed by the after bytes.
  auto NewInit = ConstantStruct::getAnon(
      {ConstantDataArray::get(M.getContext(), B.Before.Bytes),
       B.GV->getInitializer(),
       ConstantDataArray::get(M.getContext(), B.After.Bytes)});
  auto NewGV =
      new GlobalVariable(M, NewInit->getType(), B.GV->isConstant(),
                         GlobalVariable::PrivateLinkage, NewInit, "", B.GV);
  NewGV->setSection(B.GV->getSection());
  NewGV->setComdat(B.GV->getComdat());

  // Copy the original vtable's metadata to the anonymous global, adjusting
  // offsets as required.
  NewGV->copyMetadata(B.GV, B.Before.Bytes.size());

  // Build an alias named after the original global, pointing at the second
  // element (the original initializer).
  auto Alias = GlobalAlias::create(
      B.GV->getInitializer()->getType(), 0, B.GV->getLinkage(), "",
      ConstantExpr::getGetElementPtr(
          NewInit->getType(), NewGV,
          ArrayRef<Constant *>{ConstantInt::get(Int32Ty, 0),
                               ConstantInt::get(Int32Ty, 1)}),
      &M);
  Alias->setVisibility(B.GV->getVisibility());
  Alias->takeName(B.GV);

  B.GV->replaceAllUsesWith(Alias);
  B.GV->eraseFromParent();
}
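
// The resulting layout, sketched for a (hypothetical) vtable @_ZTV1A:
//
//   @anon = private constant { [N x i8], <vtable type>, [M x i8] }
//             { <before bytes>, <original initializer>, <after bytes> }
//   @_ZTV1A = alias ..., getelementptr(..., @anon, i32 0, i32 1)
//
// so existing references to the vtable still resolve to the original
// initializer, while the evaluated return values live in the byte arrays
// immediately before and after it.
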
bool DevirtModule::areRemarksEnabled() {
  const auto &FL = M.getFunctionList();
  if (FL.empty())
    return false;
  const Function &Fn = FL.front();

  const auto &BBL = Fn.getBasicBlockList();
  if (BBL.empty())
    return false;
  auto DI = OptimizationRemark(DEBUG_TYPE, "", DebugLoc(), &BBL.front());
  return DI.isEnabled();
}
void DevirtModule::scanTypeTestUsers(Function *TypeTestFunc,
                                     Function *AssumeFunc) {
  // Find all virtual calls via a virtual table pointer %p under an assumption
  // of the form llvm.assume(llvm.type.test(%p, %md)). This indicates that %p
  // points to a member of the type identifier %md. Group calls by (type ID,
  // offset) pair (effectively the identity of the virtual function) and store
  // to CallSlots.
  DenseSet<Value *> SeenPtrs;
  for (auto I = TypeTestFunc->use_begin(), E = TypeTestFunc->use_end();
       I != E;) {
    auto CI = dyn_cast<CallInst>(I->getUser());
    ++I;
    if (!CI)
      continue;

    // Search for virtual calls based on %p and add them to DevirtCalls.
    SmallVector<DevirtCallSite, 1> DevirtCalls;
    SmallVector<CallInst *, 1> Assumes;
    findDevirtualizableCallsForTypeTest(DevirtCalls, Assumes, CI);

    // If we found any, add them to CallSlots. Only do this if we haven't seen
    // the vtable pointer before, as it may have been CSE'd with pointers from
    // other call sites, and we don't want to process call sites multiple
    // times.
    if (!Assumes.empty()) {
      Metadata *TypeId =
          cast<MetadataAsValue>(CI->getArgOperand(1))->getMetadata();
      Value *Ptr = CI->getArgOperand(0)->stripPointerCasts();
      if (SeenPtrs.insert(Ptr).second) {
        for (DevirtCallSite Call : DevirtCalls) {
          CallSlots[{TypeId, Call.Offset}].addCallSite(CI->getArgOperand(0),
                                                       Call.CS, nullptr);
        }
      }
    }

    // We no longer need the assumes or the type test.
    for (auto Assume : Assumes)
      Assume->eraseFromParent();
    // We can't use RecursivelyDeleteTriviallyDeadInstructions here because we
    // may use the vtable argument later.
    if (CI->use_empty())
      CI->eraseFromParent();
  }
}
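
// For illustration, the pattern scanTypeTestUsers looks for (hypothetical
// type identifier and offset):
//
//   %p = call i1 @llvm.type.test(i8* %vtable, metadata !"_ZTS1A")
//   call void @llvm.assume(i1 %p)
//   ...
//   %fptr = ... load from %vtable at offset 8 ...
//   call void %fptr(...)
//
// The call is recorded under the slot {!"_ZTS1A", 8} in CallSlots.
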
void DevirtModule::scanTypeCheckedLoadUsers(Function *TypeCheckedLoadFunc) {
  Function *TypeTestFunc = Intrinsic::getDeclaration(&M, Intrinsic::type_test);

  for (auto I = TypeCheckedLoadFunc->use_begin(),
            E = TypeCheckedLoadFunc->use_end();
       I != E;) {
    auto CI = dyn_cast<CallInst>(I->getUser());
    ++I;
    if (!CI)
      continue;

    Value *Ptr = CI->getArgOperand(0);
    Value *Offset = CI->getArgOperand(1);
    Value *TypeIdValue = CI->getArgOperand(2);
    Metadata *TypeId = cast<MetadataAsValue>(TypeIdValue)->getMetadata();

    SmallVector<DevirtCallSite, 1> DevirtCalls;
    SmallVector<Instruction *, 1> LoadedPtrs;
    SmallVector<Instruction *, 1> Preds;
    bool HasNonCallUses = false;
    findDevirtualizableCallsForTypeCheckedLoad(DevirtCalls, LoadedPtrs, Preds,
                                               HasNonCallUses, CI);

    // Start by generating "pessimistic" code that explicitly loads the
    // function pointer from the vtable and performs the type check. If
    // possible, we will eliminate the load and the type check later.

    // If possible, only generate the load at the point where it is used.
    // This helps avoid unnecessary spills.
    IRBuilder<> LoadB(
        (LoadedPtrs.size() == 1 && !HasNonCallUses) ? LoadedPtrs[0] : CI);
    Value *GEP = LoadB.CreateGEP(Int8Ty, Ptr, Offset);
    Value *GEPPtr = LoadB.CreateBitCast(GEP, PointerType::getUnqual(Int8PtrTy));
    Value *LoadedValue = LoadB.CreateLoad(Int8PtrTy, GEPPtr);

    for (Instruction *LoadedPtr : LoadedPtrs) {
      LoadedPtr->replaceAllUsesWith(LoadedValue);
      LoadedPtr->eraseFromParent();
    }

    // Likewise for the type test.
    IRBuilder<> CallB((Preds.size() == 1 && !HasNonCallUses) ? Preds[0] : CI);
    CallInst *TypeTestCall = CallB.CreateCall(TypeTestFunc, {Ptr, TypeIdValue});

    for (Instruction *Pred : Preds) {
      Pred->replaceAllUsesWith(TypeTestCall);
      Pred->eraseFromParent();
    }

    // We have already erased any extractvalue instructions that refer to the
    // intrinsic call, but the intrinsic may have other non-extractvalue uses
    // (although this is unlikely). In that case, explicitly build a pair and
    // RAUW it.
    if (!CI->use_empty()) {
      Value *Pair = UndefValue::get(CI->getType());
      IRBuilder<> B(CI);
      Pair = B.CreateInsertValue(Pair, LoadedValue, {0});
      Pair = B.CreateInsertValue(Pair, TypeTestCall, {1});
      CI->replaceAllUsesWith(Pair);
    }

    // The number of unsafe uses is initially the number of uses.
    auto &NumUnsafeUses = NumUnsafeUsesForTypeTest[TypeTestCall];
    NumUnsafeUses = DevirtCalls.size();

    // If the function pointer has a non-call user, we cannot eliminate the
    // type check, as one of those users may eventually call the pointer.
    // Increment the unsafe use count to make sure it cannot reach zero.
    if (HasNonCallUses)
      ++NumUnsafeUses;
    for (DevirtCallSite Call : DevirtCalls) {
      CallSlots[{TypeId, Call.Offset}].addCallSite(Ptr, Call.CS,
                                                   &NumUnsafeUses);
    }

    CI->eraseFromParent();
  }
}
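
// For illustration, llvm.type.checked.load returns a {i8*, i1} pair; a use
// such as (hypothetical type identifier)
//
//   %pair = call {i8*, i1} @llvm.type.checked.load(i8* %vtable, i32 8,
//                                                  metadata !"_ZTS1A")
//   %fptr = extractvalue {i8*, i1} %pair, 0
//   %ok = extractvalue {i8*, i1} %pair, 1
//
// is expanded by the code above into an explicit gep+load for %fptr and a
// plain llvm.type.test call for %ok, which later phases may fold away.
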
void DevirtModule::importResolution(VTableSlot Slot, VTableSlotInfo &SlotInfo) {
  const TypeIdSummary *TidSummary =
      ImportSummary->getTypeIdSummary(cast<MDString>(Slot.TypeID)->getString());
  if (!TidSummary)
    return;
  auto ResI = TidSummary->WPDRes.find(Slot.ByteOffset);
  if (ResI == TidSummary->WPDRes.end())
    return;
  const WholeProgramDevirtResolution &Res = ResI->second;

  if (Res.TheKind == WholeProgramDevirtResolution::SingleImpl) {
    // The type of the function in the declaration is irrelevant because every
    // call site will cast it to the correct type.
    auto *SingleImpl = M.getOrInsertFunction(
        Res.SingleImplName, Type::getVoidTy(M.getContext()));

    // This is the import phase so we should not be exporting anything.
    bool IsExported = false;
    applySingleImplDevirt(SlotInfo, SingleImpl, IsExported);
    assert(!IsExported);
  }

  for (auto &CSByConstantArg : SlotInfo.ConstCSInfo) {
    auto I = Res.ResByArg.find(CSByConstantArg.first);
    if (I == Res.ResByArg.end())
      continue;
    auto &ResByArg = I->second;
    // FIXME: We should figure out what to do about the "function name"
    // argument to the apply* functions, as the function names are unavailable
    // during the importing phase. For now we just pass the empty string. This
    // does not impact correctness because the function names are just used
    // for remarks.
    switch (ResByArg.TheKind) {
    case WholeProgramDevirtResolution::ByArg::UniformRetVal:
      applyUniformRetValOpt(CSByConstantArg.second, "", ResByArg.Info);
      break;
    case WholeProgramDevirtResolution::ByArg::UniqueRetVal: {
      Constant *UniqueMemberAddr =
          importGlobal(Slot, CSByConstantArg.first, "unique_member");
      applyUniqueRetValOpt(CSByConstantArg.second, "", ResByArg.Info,
                           UniqueMemberAddr);
      break;
    }
    case WholeProgramDevirtResolution::ByArg::VirtualConstProp: {
      Constant *Byte = importGlobal(Slot, CSByConstantArg.first, "byte", 32);
      Byte = ConstantExpr::getPtrToInt(Byte, Int32Ty);
      Constant *Bit = importGlobal(Slot, CSByConstantArg.first, "bit", 8);
      Bit = ConstantExpr::getPtrToInt(Bit, Int8Ty);
      applyVirtualConstProp(CSByConstantArg.second, "", Byte, Bit);
      break;
    }
    default:
      break;
    }
  }
}
void DevirtModule::removeRedundantTypeTests() {
  auto True = ConstantInt::getTrue(M.getContext());
  for (auto &&U : NumUnsafeUsesForTypeTest) {
    if (U.second == 0) {
      U.first->replaceAllUsesWith(True);
      U.first->eraseFromParent();
    }
  }
}
bool DevirtModule::run() {
  Function *TypeTestFunc =
      M.getFunction(Intrinsic::getName(Intrinsic::type_test));
  Function *TypeCheckedLoadFunc =
      M.getFunction(Intrinsic::getName(Intrinsic::type_checked_load));
  Function *AssumeFunc = M.getFunction(Intrinsic::getName(Intrinsic::assume));

  // Normally if there are no users of the devirtualization intrinsics in the
  // module, this pass has nothing to do. But if we are exporting, we also need
  // to handle any users that appear only in the function summaries.
  if (!ExportSummary &&
      (!TypeTestFunc || TypeTestFunc->use_empty() || !AssumeFunc ||
       AssumeFunc->use_empty()) &&
      (!TypeCheckedLoadFunc || TypeCheckedLoadFunc->use_empty()))
    return false;

  if (TypeTestFunc && AssumeFunc)
    scanTypeTestUsers(TypeTestFunc, AssumeFunc);

  if (TypeCheckedLoadFunc)
    scanTypeCheckedLoadUsers(TypeCheckedLoadFunc);

  if (ImportSummary) {
    for (auto &S : CallSlots)
      importResolution(S.first, S.second);

    removeRedundantTypeTests();

    // The rest of the code is only necessary when exporting or during regular
    // LTO, so we are done.
    return true;
  }

  // Rebuild type metadata into a map for easy lookup.
  std::vector<VTableBits> Bits;
  DenseMap<Metadata *, std::set<TypeMemberInfo>> TypeIdMap;
  buildTypeIdentifierMap(Bits, TypeIdMap);
  if (TypeIdMap.empty())
    return true;

  // Collect information from summary about which calls to try to devirtualize.
  if (ExportSummary) {
    DenseMap<GlobalValue::GUID, TinyPtrVector<Metadata *>> MetadataByGUID;
    for (auto &P : TypeIdMap) {
      if (auto *TypeId = dyn_cast<MDString>(P.first))
        MetadataByGUID[GlobalValue::getGUID(TypeId->getString())].push_back(
            TypeId);
    }

    for (auto &P : *ExportSummary) {
      for (auto &S : P.second) {
        auto *FS = dyn_cast<FunctionSummary>(S.get());
        if (!FS)
          continue;
        // FIXME: Only add live functions.
        for (FunctionSummary::VFuncId VF : FS->type_test_assume_vcalls()) {
          for (Metadata *MD : MetadataByGUID[VF.GUID]) {
            CallSlots[{MD, VF.Offset}].CSInfo.SummaryHasTypeTestAssumeUsers =
                true;
          }
        }
        for (FunctionSummary::VFuncId VF : FS->type_checked_load_vcalls()) {
          for (Metadata *MD : MetadataByGUID[VF.GUID]) {
            CallSlots[{MD, VF.Offset}]
                .CSInfo.SummaryTypeCheckedLoadUsers.push_back(FS);
          }
        }
        for (const FunctionSummary::ConstVCall &VC :
             FS->type_test_assume_const_vcalls()) {
          for (Metadata *MD : MetadataByGUID[VC.VFunc.GUID]) {
            CallSlots[{MD, VC.VFunc.Offset}]
                .ConstCSInfo[VC.Args]
                .SummaryHasTypeTestAssumeUsers = true;
          }
        }
        for (const FunctionSummary::ConstVCall &VC :
             FS->type_checked_load_const_vcalls()) {
          for (Metadata *MD : MetadataByGUID[VC.VFunc.GUID]) {
            CallSlots[{MD, VC.VFunc.Offset}]
                .ConstCSInfo[VC.Args]
                .SummaryTypeCheckedLoadUsers.push_back(FS);
          }
        }
      }
    }
  }

  // For each (type, offset) pair:
  bool DidVirtualConstProp = false;
  std::map<std::string, Function *> DevirtTargets;
  for (auto &S : CallSlots) {
    // Search each of the members of the type identifier for the virtual
    // function implementation at offset S.first.ByteOffset, and add to
    // TargetsForSlot.
    std::vector<VirtualCallTarget> TargetsForSlot;
    if (tryFindVirtualCallTargets(TargetsForSlot, TypeIdMap[S.first.TypeID],
                                  S.first.ByteOffset)) {
      WholeProgramDevirtResolution *Res = nullptr;
      if (ExportSummary && isa<MDString>(S.first.TypeID))
        Res = &ExportSummary
                   ->getOrInsertTypeIdSummary(
                       cast<MDString>(S.first.TypeID)->getString())
                   .WPDRes[S.first.ByteOffset];

      if (!trySingleImplDevirt(TargetsForSlot, S.second, Res) &&
          tryVirtualConstProp(TargetsForSlot, S.second, Res, S.first))
        DidVirtualConstProp = true;

      // Collect functions devirtualized at least for one call site for stats.
      if (RemarksEnabled)
        for (const auto &T : TargetsForSlot)
          if (T.WasDevirt)
            DevirtTargets[T.Fn->getName()] = T.Fn;
    }

    // CFI-specific: if we are exporting and any llvm.type.checked.load
    // intrinsics were *not* devirtualized, we need to add the resulting
    // llvm.type.test intrinsics to the function summaries so that the
    // LowerTypeTests pass will export them.
    if (ExportSummary && isa<MDString>(S.first.TypeID)) {
      auto GUID =
          GlobalValue::getGUID(cast<MDString>(S.first.TypeID)->getString());
      for (auto FS : S.second.CSInfo.SummaryTypeCheckedLoadUsers)
        FS->addTypeTest(GUID);
      for (auto &CCS : S.second.ConstCSInfo)
        for (auto FS : CCS.second.SummaryTypeCheckedLoadUsers)
          FS->addTypeTest(GUID);
    }
  }

  if (RemarksEnabled) {
    // Generate remarks for each devirtualized function.
    for (const auto &DT : DevirtTargets) {
      Function *F = DT.second;
      DISubprogram *SP = F->getSubprogram();
      emitOptimizationRemark(F->getContext(), DEBUG_TYPE, *F, SP,
                             Twine("devirtualized ") + F->getName());
    }
  }

  removeRedundantTypeTests();

  // Rebuild each global we touched as part of virtual constant propagation to
  // include the before and after bytes.
  if (DidVirtualConstProp)
    for (VTableBits &B : Bits)
      rebuildGlobal(B);

  return true;
}