123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374 |
- //===-- ForwardControlFlowIntegrity.cpp: Forward-Edge CFI -----------------===//
- //
- // This file is distributed under the University of Illinois Open Source
- // License. See LICENSE.TXT for details.
- //
- //===----------------------------------------------------------------------===//
- ///
- /// \file
- /// \brief A pass that instruments code with fast checks for indirect calls and
- /// hooks for a function to check violations.
- ///
- //===----------------------------------------------------------------------===//
- #define DEBUG_TYPE "cfi"
- #include "llvm/ADT/SmallVector.h"
- #include "llvm/ADT/Statistic.h"
- #include "llvm/Analysis/JumpInstrTableInfo.h"
- #include "llvm/CodeGen/ForwardControlFlowIntegrity.h"
- #include "llvm/CodeGen/JumpInstrTables.h"
- #include "llvm/CodeGen/Passes.h"
- #include "llvm/IR/Attributes.h"
- #include "llvm/IR/CallSite.h"
- #include "llvm/IR/Constants.h"
- #include "llvm/IR/DerivedTypes.h"
- #include "llvm/IR/Function.h"
- #include "llvm/IR/GlobalValue.h"
- #include "llvm/IR/IRBuilder.h"
- #include "llvm/IR/InlineAsm.h"
- #include "llvm/IR/Instructions.h"
- #include "llvm/IR/LLVMContext.h"
- #include "llvm/IR/Module.h"
- #include "llvm/IR/Operator.h"
- #include "llvm/IR/Type.h"
- #include "llvm/IR/Verifier.h"
- #include "llvm/Pass.h"
- #include "llvm/Support/CommandLine.h"
- #include "llvm/Support/Debug.h"
- #include "llvm/Support/raw_ostream.h"
- using namespace llvm;
- STATISTIC(NumCFIIndirectCalls,
- "Number of indirect call sites rewritten by the CFI pass");
- char ForwardControlFlowIntegrity::ID = 0;
- INITIALIZE_PASS_BEGIN(ForwardControlFlowIntegrity, "forward-cfi",
- "Control-Flow Integrity", true, true)
- INITIALIZE_PASS_DEPENDENCY(JumpInstrTableInfo);
- INITIALIZE_PASS_DEPENDENCY(JumpInstrTables);
- INITIALIZE_PASS_END(ForwardControlFlowIntegrity, "forward-cfi",
- "Control-Flow Integrity", true, true)
- ModulePass *llvm::createForwardControlFlowIntegrityPass() {
- return new ForwardControlFlowIntegrity();
- }
- ModulePass *llvm::createForwardControlFlowIntegrityPass(
- JumpTable::JumpTableType JTT, CFIntegrity CFIType, bool CFIEnforcing,
- StringRef CFIFuncName) {
- return new ForwardControlFlowIntegrity(JTT, CFIType, CFIEnforcing,
- CFIFuncName);
- }
- // Checks to see if a given CallSite is making an indirect call, including
- // cases where the indirect call is made through a bitcast.
- static bool isIndirectCall(CallSite &CS) {
- if (CS.getCalledFunction())
- return false;
- // Check the value to see if it is merely a bitcast of a function. In
- // this case, it will translate to a direct function call in the resulting
- // assembly, so we won't treat it as an indirect call here.
- const Value *V = CS.getCalledValue();
- if (const ConstantExpr *CE = dyn_cast<ConstantExpr>(V)) {
- return !(CE->isCast() && isa<Function>(CE->getOperand(0)));
- }
- // Otherwise, since we know it's a call, it must be an indirect call
- return true;
- }
// Name of the warning function used when no custom name was configured
// (CFIFuncName is empty); see addWarningFunction() and insertWarning().
static const char cfi_failure_func_name[] =
    "__llvm_cfi_pointer_warning";
- ForwardControlFlowIntegrity::ForwardControlFlowIntegrity()
- : ModulePass(ID), IndirectCalls(), JTType(JumpTable::Single),
- CFIType(CFIntegrity::Sub), CFIEnforcing(false), CFIFuncName("") {
- initializeForwardControlFlowIntegrityPass(*PassRegistry::getPassRegistry());
- }
- ForwardControlFlowIntegrity::ForwardControlFlowIntegrity(
- JumpTable::JumpTableType JTT, CFIntegrity CFIType, bool CFIEnforcing,
- std::string CFIFuncName)
- : ModulePass(ID), IndirectCalls(), JTType(JTT), CFIType(CFIType),
- CFIEnforcing(CFIEnforcing), CFIFuncName(CFIFuncName) {
- initializeForwardControlFlowIntegrityPass(*PassRegistry::getPassRegistry());
- }
- ForwardControlFlowIntegrity::~ForwardControlFlowIntegrity() {}
- void ForwardControlFlowIntegrity::getAnalysisUsage(AnalysisUsage &AU) const {
- AU.addRequired<JumpInstrTableInfo>();
- AU.addRequired<JumpInstrTables>();
- }
- void ForwardControlFlowIntegrity::getIndirectCalls(Module &M) {
- // To get the indirect calls, we iterate over all functions and iterate over
- // the list of basic blocks in each. We extract a total list of indirect calls
- // before modifying any of them, since our modifications will modify the list
- // of basic blocks.
- for (Function &F : M) {
- for (BasicBlock &BB : F) {
- for (Instruction &I : BB) {
- CallSite CS(&I);
- if (!(CS && isIndirectCall(CS)))
- continue;
- Value *CalledValue = CS.getCalledValue();
- // Don't rewrite this instruction if the indirect call is actually just
- // inline assembly, since our transformation will generate an invalid
- // module in that case.
- if (isa<InlineAsm>(CalledValue))
- continue;
- IndirectCalls.push_back(&I);
- }
- }
- }
- }
- void ForwardControlFlowIntegrity::updateIndirectCalls(Module &M,
- CFITables &CFIT) {
- Type *Int64Ty = Type::getInt64Ty(M.getContext());
- for (Instruction *I : IndirectCalls) {
- CallSite CS(I);
- Value *CalledValue = CS.getCalledValue();
- // Get the function type for this call and look it up in the tables.
- Type *VTy = CalledValue->getType();
- PointerType *PTy = dyn_cast<PointerType>(VTy);
- Type *EltTy = PTy->getElementType();
- FunctionType *FunTy = dyn_cast<FunctionType>(EltTy);
- FunctionType *TransformedTy = JumpInstrTables::transformType(JTType, FunTy);
- ++NumCFIIndirectCalls;
- Constant *JumpTableStart = nullptr;
- Constant *JumpTableMask = nullptr;
- Constant *JumpTableSize = nullptr;
- // Some call sites have function types that don't correspond to any
- // address-taken function in the module. This happens when function pointers
- // are passed in from external code.
- auto it = CFIT.find(TransformedTy);
- if (it == CFIT.end()) {
- // In this case, make sure that the function pointer will change by
- // setting the mask and the start to be 0 so that the transformed
- // function is 0.
- JumpTableStart = ConstantInt::get(Int64Ty, 0);
- JumpTableMask = ConstantInt::get(Int64Ty, 0);
- JumpTableSize = ConstantInt::get(Int64Ty, 0);
- } else {
- JumpTableStart = it->second.StartValue;
- JumpTableMask = it->second.MaskValue;
- JumpTableSize = it->second.Size;
- }
- rewriteFunctionPointer(M, I, CalledValue, JumpTableStart, JumpTableMask,
- JumpTableSize);
- }
- return;
- }
- bool ForwardControlFlowIntegrity::runOnModule(Module &M) {
- JumpInstrTableInfo *JITI = &getAnalysis<JumpInstrTableInfo>();
- Type *Int64Ty = Type::getInt64Ty(M.getContext());
- Type *VoidPtrTy = Type::getInt8PtrTy(M.getContext());
- // JumpInstrTableInfo stores information about the alignment of each entry.
- // The alignment returned by JumpInstrTableInfo is alignment in bytes, not
- // in the exponent.
- ByteAlignment = JITI->entryByteAlignment();
- LogByteAlignment = llvm::Log2_64(ByteAlignment);
- // Set up tables for control-flow integrity based on information about the
- // jump-instruction tables.
- CFITables CFIT;
- for (const auto &KV : JITI->getTables()) {
- uint64_t Size = static_cast<uint64_t>(KV.second.size());
- uint64_t TableSize = NextPowerOf2(Size);
- int64_t MaskValue = ((TableSize << LogByteAlignment) - 1) & -ByteAlignment;
- Constant *JumpTableMaskValue = ConstantInt::get(Int64Ty, MaskValue);
- Constant *JumpTableSize = ConstantInt::get(Int64Ty, Size);
- // The base of the table is defined to be the first jumptable function in
- // the table.
- Function *First = KV.second.begin()->second;
- Constant *JumpTableStartValue = ConstantExpr::getBitCast(First, VoidPtrTy);
- CFIT[KV.first].StartValue = JumpTableStartValue;
- CFIT[KV.first].MaskValue = JumpTableMaskValue;
- CFIT[KV.first].Size = JumpTableSize;
- }
- if (CFIT.empty())
- return false;
- getIndirectCalls(M);
- if (!CFIEnforcing) {
- addWarningFunction(M);
- }
- // Update the instructions with the check and the indirect jump through our
- // table.
- updateIndirectCalls(M, CFIT);
- return true;
- }
- void ForwardControlFlowIntegrity::addWarningFunction(Module &M) {
- PointerType *CharPtrTy = Type::getInt8PtrTy(M.getContext());
- // Get the type of the Warning Function: void (i8*, i8*),
- // where the first argument is the name of the function in which the violation
- // occurs, and the second is the function pointer that violates CFI.
- SmallVector<Type *, 2> WarningFunArgs;
- WarningFunArgs.push_back(CharPtrTy);
- WarningFunArgs.push_back(CharPtrTy);
- FunctionType *WarningFunTy =
- FunctionType::get(Type::getVoidTy(M.getContext()), WarningFunArgs, false);
- if (!CFIFuncName.empty()) {
- Constant *FailureFun = M.getOrInsertFunction(CFIFuncName, WarningFunTy);
- if (!FailureFun)
- report_fatal_error("Could not get or insert the function specified by"
- " -cfi-func-name");
- } else {
- // The default warning function swallows the warning and lets the call
- // continue, since there's no generic way for it to print out this
- // information.
- Function *WarningFun = M.getFunction(cfi_failure_func_name);
- if (!WarningFun) {
- WarningFun =
- Function::Create(WarningFunTy, GlobalValue::LinkOnceAnyLinkage,
- cfi_failure_func_name, &M);
- }
- BasicBlock *Entry =
- BasicBlock::Create(M.getContext(), "entry", WarningFun, 0);
- ReturnInst::Create(M.getContext(), Entry);
- }
- }
- void ForwardControlFlowIntegrity::rewriteFunctionPointer(
- Module &M, Instruction *I, Value *FunPtr, Constant *JumpTableStart,
- Constant *JumpTableMask, Constant *JumpTableSize) {
- IRBuilder<> TempBuilder(I);
- Type *OrigFunType = FunPtr->getType();
- BasicBlock *CurBB = cast<BasicBlock>(I->getParent());
- Function *CurF = cast<Function>(CurBB->getParent());
- Type *Int64Ty = Type::getInt64Ty(M.getContext());
- Value *TI = TempBuilder.CreatePtrToInt(FunPtr, Int64Ty);
- Value *TStartInt = TempBuilder.CreatePtrToInt(JumpTableStart, Int64Ty);
- Value *NewFunPtr = nullptr;
- Value *Check = nullptr;
- switch (CFIType) {
- case CFIntegrity::Sub: {
- // This is the subtract, mask, and add version.
- // Subtract from the base.
- Value *Sub = TempBuilder.CreateSub(TI, TStartInt);
- // Mask the difference to force this to be a table offset.
- Value *And = TempBuilder.CreateAnd(Sub, JumpTableMask);
- // Add it back to the base.
- Value *Result = TempBuilder.CreateAdd(And, TStartInt);
- // Convert it back into a function pointer that we can call.
- NewFunPtr = TempBuilder.CreateIntToPtr(Result, OrigFunType);
- break;
- }
- case CFIntegrity::Ror: {
- // This is the subtract and rotate version.
- // Rotate right by the alignment value. The optimizer should recognize
- // this sequence as a rotation.
- // This cast is safe, since unsigned is always a subset of uint64_t.
- uint64_t LogByteAlignment64 = static_cast<uint64_t>(LogByteAlignment);
- Constant *RightShift = ConstantInt::get(Int64Ty, LogByteAlignment64);
- Constant *LeftShift = ConstantInt::get(Int64Ty, 64 - LogByteAlignment64);
- // Subtract from the base.
- Value *Sub = TempBuilder.CreateSub(TI, TStartInt);
- // Create the equivalent of a rotate-right instruction.
- Value *Shr = TempBuilder.CreateLShr(Sub, RightShift);
- Value *Shl = TempBuilder.CreateShl(Sub, LeftShift);
- Value *Or = TempBuilder.CreateOr(Shr, Shl);
- // Perform unsigned comparison to check for inclusion in the table.
- Check = TempBuilder.CreateICmpULT(Or, JumpTableSize);
- NewFunPtr = FunPtr;
- break;
- }
- case CFIntegrity::Add: {
- // This is the mask and add version.
- // Mask the function pointer to turn it into an offset into the table.
- Value *And = TempBuilder.CreateAnd(TI, JumpTableMask);
- // Then or this offset to the base and get the pointer value.
- Value *Result = TempBuilder.CreateAdd(And, TStartInt);
- // Convert it back into a function pointer that we can call.
- NewFunPtr = TempBuilder.CreateIntToPtr(Result, OrigFunType);
- break;
- }
- }
- if (!CFIEnforcing) {
- // If a check hasn't been added (in the rotation version), then check to see
- // if it's the same as the original function. This check determines whether
- // or not we call the CFI failure function.
- if (!Check)
- Check = TempBuilder.CreateICmpEQ(NewFunPtr, FunPtr);
- BasicBlock *InvalidPtrBlock =
- BasicBlock::Create(M.getContext(), "invalid.ptr", CurF, 0);
- BasicBlock *ContinuationBB = CurBB->splitBasicBlock(I);
- // Remove the unconditional branch that connects the two blocks.
- TerminatorInst *TermInst = CurBB->getTerminator();
- TermInst->eraseFromParent();
- // Add a conditional branch that depends on the Check above.
- BranchInst::Create(ContinuationBB, InvalidPtrBlock, Check, CurBB);
- // Call the warning function for this pointer, then continue.
- Instruction *BI = BranchInst::Create(ContinuationBB, InvalidPtrBlock);
- insertWarning(M, InvalidPtrBlock, BI, FunPtr);
- } else {
- // Modify the instruction to call this value.
- CallSite CS(I);
- CS.setCalledFunction(NewFunPtr);
- }
- }
- void ForwardControlFlowIntegrity::insertWarning(Module &M, BasicBlock *Block,
- Instruction *I, Value *FunPtr) {
- Function *ParentFun = cast<Function>(Block->getParent());
- // Get the function to call right before the instruction.
- Function *WarningFun = nullptr;
- if (CFIFuncName.empty()) {
- WarningFun = M.getFunction(cfi_failure_func_name);
- } else {
- WarningFun = M.getFunction(CFIFuncName);
- }
- assert(WarningFun && "Could not find the CFI failure function");
- Type *VoidPtrTy = Type::getInt8PtrTy(M.getContext());
- IRBuilder<> WarningInserter(I);
- // Create a mergeable GlobalVariable containing the name of the function.
- Value *ParentNameGV =
- WarningInserter.CreateGlobalString(ParentFun->getName());
- Value *ParentNamePtr = WarningInserter.CreateBitCast(ParentNameGV, VoidPtrTy);
- Value *FunVoidPtr = WarningInserter.CreateBitCast(FunPtr, VoidPtrTy);
- WarningInserter.CreateCall2(WarningFun, ParentNamePtr, FunVoidPtr);
- }
|