ForwardControlFlowIntegrity.cpp 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374
  1. //===-- ForwardControlFlowIntegrity.cpp: Forward-Edge CFI -----------------===//
  2. //
  3. // This file is distributed under the University of Illinois Open Source
  4. // License. See LICENSE.TXT for details.
  5. //
  6. //===----------------------------------------------------------------------===//
  7. ///
  8. /// \file
  9. /// \brief A pass that instruments code with fast checks for indirect calls and
  10. /// hooks for a function to check violations.
  11. ///
  12. //===----------------------------------------------------------------------===//
  13. #define DEBUG_TYPE "cfi"
  14. #include "llvm/ADT/SmallVector.h"
  15. #include "llvm/ADT/Statistic.h"
  16. #include "llvm/Analysis/JumpInstrTableInfo.h"
  17. #include "llvm/CodeGen/ForwardControlFlowIntegrity.h"
  18. #include "llvm/CodeGen/JumpInstrTables.h"
  19. #include "llvm/CodeGen/Passes.h"
  20. #include "llvm/IR/Attributes.h"
  21. #include "llvm/IR/CallSite.h"
  22. #include "llvm/IR/Constants.h"
  23. #include "llvm/IR/DerivedTypes.h"
  24. #include "llvm/IR/Function.h"
  25. #include "llvm/IR/GlobalValue.h"
  26. #include "llvm/IR/IRBuilder.h"
  27. #include "llvm/IR/InlineAsm.h"
  28. #include "llvm/IR/Instructions.h"
  29. #include "llvm/IR/LLVMContext.h"
  30. #include "llvm/IR/Module.h"
  31. #include "llvm/IR/Operator.h"
  32. #include "llvm/IR/Type.h"
  33. #include "llvm/IR/Verifier.h"
  34. #include "llvm/Pass.h"
  35. #include "llvm/Support/CommandLine.h"
  36. #include "llvm/Support/Debug.h"
  37. #include "llvm/Support/raw_ostream.h"
  38. using namespace llvm;
  // Statistic incremented once per call site processed by updateIndirectCalls.
  STATISTIC(NumCFIIndirectCalls,
            "Number of indirect call sites rewritten by the CFI pass");

  // Pass ID object; it is handed to the ModulePass base-class constructor by
  // both ForwardControlFlowIntegrity constructors.
  char ForwardControlFlowIntegrity::ID = 0;

  // Register the pass under the command-line name "forward-cfi" and list the
  // two passes it depends on (the same two requested in getAnalysisUsage).
  INITIALIZE_PASS_BEGIN(ForwardControlFlowIntegrity, "forward-cfi",
                        "Control-Flow Integrity", true, true)
  INITIALIZE_PASS_DEPENDENCY(JumpInstrTableInfo);
  INITIALIZE_PASS_DEPENDENCY(JumpInstrTables);
  INITIALIZE_PASS_END(ForwardControlFlowIntegrity, "forward-cfi",
                      "Control-Flow Integrity", true, true)
  48. ModulePass *llvm::createForwardControlFlowIntegrityPass() {
  49. return new ForwardControlFlowIntegrity();
  50. }
  51. ModulePass *llvm::createForwardControlFlowIntegrityPass(
  52. JumpTable::JumpTableType JTT, CFIntegrity CFIType, bool CFIEnforcing,
  53. StringRef CFIFuncName) {
  54. return new ForwardControlFlowIntegrity(JTT, CFIType, CFIEnforcing,
  55. CFIFuncName);
  56. }
  57. // Checks to see if a given CallSite is making an indirect call, including
  58. // cases where the indirect call is made through a bitcast.
  59. static bool isIndirectCall(CallSite &CS) {
  60. if (CS.getCalledFunction())
  61. return false;
  62. // Check the value to see if it is merely a bitcast of a function. In
  63. // this case, it will translate to a direct function call in the resulting
  64. // assembly, so we won't treat it as an indirect call here.
  65. const Value *V = CS.getCalledValue();
  66. if (const ConstantExpr *CE = dyn_cast<ConstantExpr>(V)) {
  67. return !(CE->isCast() && isa<Function>(CE->getOperand(0)));
  68. }
  69. // Otherwise, since we know it's a call, it must be an indirect call
  70. return true;
  71. }
  // Name of the default failure function; it is used by addWarningFunction and
  // insertWarning whenever CFIFuncName is empty.
  static const char cfi_failure_func_name[] = "__llvm_cfi_pointer_warning";
  73. ForwardControlFlowIntegrity::ForwardControlFlowIntegrity()
  74. : ModulePass(ID), IndirectCalls(), JTType(JumpTable::Single),
  75. CFIType(CFIntegrity::Sub), CFIEnforcing(false), CFIFuncName("") {
  76. initializeForwardControlFlowIntegrityPass(*PassRegistry::getPassRegistry());
  77. }
  78. ForwardControlFlowIntegrity::ForwardControlFlowIntegrity(
  79. JumpTable::JumpTableType JTT, CFIntegrity CFIType, bool CFIEnforcing,
  80. std::string CFIFuncName)
  81. : ModulePass(ID), IndirectCalls(), JTType(JTT), CFIType(CFIType),
  82. CFIEnforcing(CFIEnforcing), CFIFuncName(CFIFuncName) {
  83. initializeForwardControlFlowIntegrityPass(*PassRegistry::getPassRegistry());
  84. }
  85. ForwardControlFlowIntegrity::~ForwardControlFlowIntegrity() {}
  86. void ForwardControlFlowIntegrity::getAnalysisUsage(AnalysisUsage &AU) const {
  87. AU.addRequired<JumpInstrTableInfo>();
  88. AU.addRequired<JumpInstrTables>();
  89. }
  90. void ForwardControlFlowIntegrity::getIndirectCalls(Module &M) {
  91. // To get the indirect calls, we iterate over all functions and iterate over
  92. // the list of basic blocks in each. We extract a total list of indirect calls
  93. // before modifying any of them, since our modifications will modify the list
  94. // of basic blocks.
  95. for (Function &F : M) {
  96. for (BasicBlock &BB : F) {
  97. for (Instruction &I : BB) {
  98. CallSite CS(&I);
  99. if (!(CS && isIndirectCall(CS)))
  100. continue;
  101. Value *CalledValue = CS.getCalledValue();
  102. // Don't rewrite this instruction if the indirect call is actually just
  103. // inline assembly, since our transformation will generate an invalid
  104. // module in that case.
  105. if (isa<InlineAsm>(CalledValue))
  106. continue;
  107. IndirectCalls.push_back(&I);
  108. }
  109. }
  110. }
  111. }
  112. void ForwardControlFlowIntegrity::updateIndirectCalls(Module &M,
  113. CFITables &CFIT) {
  114. Type *Int64Ty = Type::getInt64Ty(M.getContext());
  115. for (Instruction *I : IndirectCalls) {
  116. CallSite CS(I);
  117. Value *CalledValue = CS.getCalledValue();
  118. // Get the function type for this call and look it up in the tables.
  119. Type *VTy = CalledValue->getType();
  120. PointerType *PTy = dyn_cast<PointerType>(VTy);
  121. Type *EltTy = PTy->getElementType();
  122. FunctionType *FunTy = dyn_cast<FunctionType>(EltTy);
  123. FunctionType *TransformedTy = JumpInstrTables::transformType(JTType, FunTy);
  124. ++NumCFIIndirectCalls;
  125. Constant *JumpTableStart = nullptr;
  126. Constant *JumpTableMask = nullptr;
  127. Constant *JumpTableSize = nullptr;
  128. // Some call sites have function types that don't correspond to any
  129. // address-taken function in the module. This happens when function pointers
  130. // are passed in from external code.
  131. auto it = CFIT.find(TransformedTy);
  132. if (it == CFIT.end()) {
  133. // In this case, make sure that the function pointer will change by
  134. // setting the mask and the start to be 0 so that the transformed
  135. // function is 0.
  136. JumpTableStart = ConstantInt::get(Int64Ty, 0);
  137. JumpTableMask = ConstantInt::get(Int64Ty, 0);
  138. JumpTableSize = ConstantInt::get(Int64Ty, 0);
  139. } else {
  140. JumpTableStart = it->second.StartValue;
  141. JumpTableMask = it->second.MaskValue;
  142. JumpTableSize = it->second.Size;
  143. }
  144. rewriteFunctionPointer(M, I, CalledValue, JumpTableStart, JumpTableMask,
  145. JumpTableSize);
  146. }
  147. return;
  148. }
  149. bool ForwardControlFlowIntegrity::runOnModule(Module &M) {
  150. JumpInstrTableInfo *JITI = &getAnalysis<JumpInstrTableInfo>();
  151. Type *Int64Ty = Type::getInt64Ty(M.getContext());
  152. Type *VoidPtrTy = Type::getInt8PtrTy(M.getContext());
  153. // JumpInstrTableInfo stores information about the alignment of each entry.
  154. // The alignment returned by JumpInstrTableInfo is alignment in bytes, not
  155. // in the exponent.
  156. ByteAlignment = JITI->entryByteAlignment();
  157. LogByteAlignment = llvm::Log2_64(ByteAlignment);
  158. // Set up tables for control-flow integrity based on information about the
  159. // jump-instruction tables.
  160. CFITables CFIT;
  161. for (const auto &KV : JITI->getTables()) {
  162. uint64_t Size = static_cast<uint64_t>(KV.second.size());
  163. uint64_t TableSize = NextPowerOf2(Size);
  164. int64_t MaskValue = ((TableSize << LogByteAlignment) - 1) & -ByteAlignment;
  165. Constant *JumpTableMaskValue = ConstantInt::get(Int64Ty, MaskValue);
  166. Constant *JumpTableSize = ConstantInt::get(Int64Ty, Size);
  167. // The base of the table is defined to be the first jumptable function in
  168. // the table.
  169. Function *First = KV.second.begin()->second;
  170. Constant *JumpTableStartValue = ConstantExpr::getBitCast(First, VoidPtrTy);
  171. CFIT[KV.first].StartValue = JumpTableStartValue;
  172. CFIT[KV.first].MaskValue = JumpTableMaskValue;
  173. CFIT[KV.first].Size = JumpTableSize;
  174. }
  175. if (CFIT.empty())
  176. return false;
  177. getIndirectCalls(M);
  178. if (!CFIEnforcing) {
  179. addWarningFunction(M);
  180. }
  181. // Update the instructions with the check and the indirect jump through our
  182. // table.
  183. updateIndirectCalls(M, CFIT);
  184. return true;
  185. }
  186. void ForwardControlFlowIntegrity::addWarningFunction(Module &M) {
  187. PointerType *CharPtrTy = Type::getInt8PtrTy(M.getContext());
  188. // Get the type of the Warning Function: void (i8*, i8*),
  189. // where the first argument is the name of the function in which the violation
  190. // occurs, and the second is the function pointer that violates CFI.
  191. SmallVector<Type *, 2> WarningFunArgs;
  192. WarningFunArgs.push_back(CharPtrTy);
  193. WarningFunArgs.push_back(CharPtrTy);
  194. FunctionType *WarningFunTy =
  195. FunctionType::get(Type::getVoidTy(M.getContext()), WarningFunArgs, false);
  196. if (!CFIFuncName.empty()) {
  197. Constant *FailureFun = M.getOrInsertFunction(CFIFuncName, WarningFunTy);
  198. if (!FailureFun)
  199. report_fatal_error("Could not get or insert the function specified by"
  200. " -cfi-func-name");
  201. } else {
  202. // The default warning function swallows the warning and lets the call
  203. // continue, since there's no generic way for it to print out this
  204. // information.
  205. Function *WarningFun = M.getFunction(cfi_failure_func_name);
  206. if (!WarningFun) {
  207. WarningFun =
  208. Function::Create(WarningFunTy, GlobalValue::LinkOnceAnyLinkage,
  209. cfi_failure_func_name, &M);
  210. }
  211. BasicBlock *Entry =
  212. BasicBlock::Create(M.getContext(), "entry", WarningFun, 0);
  213. ReturnInst::Create(M.getContext(), Entry);
  214. }
  215. }
  216. void ForwardControlFlowIntegrity::rewriteFunctionPointer(
  217. Module &M, Instruction *I, Value *FunPtr, Constant *JumpTableStart,
  218. Constant *JumpTableMask, Constant *JumpTableSize) {
  219. IRBuilder<> TempBuilder(I);
  220. Type *OrigFunType = FunPtr->getType();
  221. BasicBlock *CurBB = cast<BasicBlock>(I->getParent());
  222. Function *CurF = cast<Function>(CurBB->getParent());
  223. Type *Int64Ty = Type::getInt64Ty(M.getContext());
  224. Value *TI = TempBuilder.CreatePtrToInt(FunPtr, Int64Ty);
  225. Value *TStartInt = TempBuilder.CreatePtrToInt(JumpTableStart, Int64Ty);
  226. Value *NewFunPtr = nullptr;
  227. Value *Check = nullptr;
  228. switch (CFIType) {
  229. case CFIntegrity::Sub: {
  230. // This is the subtract, mask, and add version.
  231. // Subtract from the base.
  232. Value *Sub = TempBuilder.CreateSub(TI, TStartInt);
  233. // Mask the difference to force this to be a table offset.
  234. Value *And = TempBuilder.CreateAnd(Sub, JumpTableMask);
  235. // Add it back to the base.
  236. Value *Result = TempBuilder.CreateAdd(And, TStartInt);
  237. // Convert it back into a function pointer that we can call.
  238. NewFunPtr = TempBuilder.CreateIntToPtr(Result, OrigFunType);
  239. break;
  240. }
  241. case CFIntegrity::Ror: {
  242. // This is the subtract and rotate version.
  243. // Rotate right by the alignment value. The optimizer should recognize
  244. // this sequence as a rotation.
  245. // This cast is safe, since unsigned is always a subset of uint64_t.
  246. uint64_t LogByteAlignment64 = static_cast<uint64_t>(LogByteAlignment);
  247. Constant *RightShift = ConstantInt::get(Int64Ty, LogByteAlignment64);
  248. Constant *LeftShift = ConstantInt::get(Int64Ty, 64 - LogByteAlignment64);
  249. // Subtract from the base.
  250. Value *Sub = TempBuilder.CreateSub(TI, TStartInt);
  251. // Create the equivalent of a rotate-right instruction.
  252. Value *Shr = TempBuilder.CreateLShr(Sub, RightShift);
  253. Value *Shl = TempBuilder.CreateShl(Sub, LeftShift);
  254. Value *Or = TempBuilder.CreateOr(Shr, Shl);
  255. // Perform unsigned comparison to check for inclusion in the table.
  256. Check = TempBuilder.CreateICmpULT(Or, JumpTableSize);
  257. NewFunPtr = FunPtr;
  258. break;
  259. }
  260. case CFIntegrity::Add: {
  261. // This is the mask and add version.
  262. // Mask the function pointer to turn it into an offset into the table.
  263. Value *And = TempBuilder.CreateAnd(TI, JumpTableMask);
  264. // Then or this offset to the base and get the pointer value.
  265. Value *Result = TempBuilder.CreateAdd(And, TStartInt);
  266. // Convert it back into a function pointer that we can call.
  267. NewFunPtr = TempBuilder.CreateIntToPtr(Result, OrigFunType);
  268. break;
  269. }
  270. }
  271. if (!CFIEnforcing) {
  272. // If a check hasn't been added (in the rotation version), then check to see
  273. // if it's the same as the original function. This check determines whether
  274. // or not we call the CFI failure function.
  275. if (!Check)
  276. Check = TempBuilder.CreateICmpEQ(NewFunPtr, FunPtr);
  277. BasicBlock *InvalidPtrBlock =
  278. BasicBlock::Create(M.getContext(), "invalid.ptr", CurF, 0);
  279. BasicBlock *ContinuationBB = CurBB->splitBasicBlock(I);
  280. // Remove the unconditional branch that connects the two blocks.
  281. TerminatorInst *TermInst = CurBB->getTerminator();
  282. TermInst->eraseFromParent();
  283. // Add a conditional branch that depends on the Check above.
  284. BranchInst::Create(ContinuationBB, InvalidPtrBlock, Check, CurBB);
  285. // Call the warning function for this pointer, then continue.
  286. Instruction *BI = BranchInst::Create(ContinuationBB, InvalidPtrBlock);
  287. insertWarning(M, InvalidPtrBlock, BI, FunPtr);
  288. } else {
  289. // Modify the instruction to call this value.
  290. CallSite CS(I);
  291. CS.setCalledFunction(NewFunPtr);
  292. }
  293. }
  294. void ForwardControlFlowIntegrity::insertWarning(Module &M, BasicBlock *Block,
  295. Instruction *I, Value *FunPtr) {
  296. Function *ParentFun = cast<Function>(Block->getParent());
  297. // Get the function to call right before the instruction.
  298. Function *WarningFun = nullptr;
  299. if (CFIFuncName.empty()) {
  300. WarningFun = M.getFunction(cfi_failure_func_name);
  301. } else {
  302. WarningFun = M.getFunction(CFIFuncName);
  303. }
  304. assert(WarningFun && "Could not find the CFI failure function");
  305. Type *VoidPtrTy = Type::getInt8PtrTy(M.getContext());
  306. IRBuilder<> WarningInserter(I);
  307. // Create a mergeable GlobalVariable containing the name of the function.
  308. Value *ParentNameGV =
  309. WarningInserter.CreateGlobalString(ParentFun->getName());
  310. Value *ParentNamePtr = WarningInserter.CreateBitCast(ParentNameGV, VoidPtrTy);
  311. Value *FunVoidPtr = WarningInserter.CreateBitCast(FunPtr, VoidPtrTy);
  312. WarningInserter.CreateCall2(WarningFun, ParentNamePtr, FunVoidPtr);
  313. }