Browse Source

Make SwitchInstProfUpdateWrapper safer

While prof branch_weights inconsistencies are being fixed patch
by patch (pass by pass) we need SwitchInstProfUpdateWrapper to
be safe with respect to inconsistent metadata that can come from
passes that have not been fixed yet. See the bug found by @nikic
in https://reviews.llvm.org/D62126.

This patch introduces one more state (called Invalid) to the
wrapper class that allows users to work with the underlying
SwitchInst ignoring the prof metadata changes.

Created a unit test for the SwitchInstProfUpdateWrapper class.

Reviewers: davidx, nikic, eraman, reames, chandlerc
Reviewed By: davidx
Differential Revision: https://reviews.llvm.org/D62656

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@362473 91177308-0d34-0410-b5e6-96231b3b80d8
Yevgeny Rouban 6 years ago
parent
commit
48973d77b6
3 changed files with 132 additions and 24 deletions
  1. 14 6
      include/llvm/IR/Instructions.h
  2. 39 18
      lib/IR/Instructions.cpp
  3. 79 0
      unittests/IR/InstructionsTest.cpp

+ 14 - 6
include/llvm/IR/Instructions.h

@@ -3439,15 +3439,24 @@ public:
 /// their prof branch_weights metadata.
 /// their prof branch_weights metadata.
 class SwitchInstProfUpdateWrapper {
 class SwitchInstProfUpdateWrapper {
   SwitchInst &SI;
   SwitchInst &SI;
-  Optional<SmallVector<uint32_t, 8> > Weights;
-  bool Changed = false;
+  Optional<SmallVector<uint32_t, 8> > Weights = None;
+
+  // Sticky invalid state is needed to safely ignore operations with prof data
+  // in cases where SwitchInstProfUpdateWrapper is created from SwitchInst
+  // with inconsistent prof data. TODO: once we fix all prof data
+  // inconsistencies we can turn invalid state to assertions.
+  enum {
+    Invalid,
+    Initialized,
+    Changed
+  } State = Invalid;
 
 
 protected:
 protected:
   static MDNode *getProfBranchWeightsMD(const SwitchInst &SI);
   static MDNode *getProfBranchWeightsMD(const SwitchInst &SI);
 
 
   MDNode *buildProfBranchWeightsMD();
   MDNode *buildProfBranchWeightsMD();
 
 
-  Optional<SmallVector<uint32_t, 8> > getProfBranchWeights();
+  void init();
 
 
 public:
 public:
   using CaseWeightOpt = Optional<uint32_t>;
   using CaseWeightOpt = Optional<uint32_t>;
@@ -3455,11 +3464,10 @@ public:
   SwitchInst &operator*() { return SI; }
   SwitchInst &operator*() { return SI; }
   operator SwitchInst *() { return &SI; }
   operator SwitchInst *() { return &SI; }
 
 
-  SwitchInstProfUpdateWrapper(SwitchInst &SI)
-      : SI(SI), Weights(getProfBranchWeights()) {}
+  SwitchInstProfUpdateWrapper(SwitchInst &SI) : SI(SI) { init(); }
 
 
   ~SwitchInstProfUpdateWrapper() {
   ~SwitchInstProfUpdateWrapper() {
-    if (Changed)
+    if (State == Changed)
       SI.setMetadata(LLVMContext::MD_prof, buildProfBranchWeightsMD());
       SI.setMetadata(LLVMContext::MD_prof, buildProfBranchWeightsMD());
   }
   }
 
 

+ 39 - 18
lib/IR/Instructions.cpp

@@ -45,6 +45,12 @@
 
 
 using namespace llvm;
 using namespace llvm;
 
 
+static cl::opt<bool> SwitchInstProfUpdateWrapperStrict(
+    "switch-inst-prof-update-wrapper-strict", cl::Hidden,
+    cl::desc("Assert that prof branch_weights metadata is valid when creating "
+             "an instance of SwitchInstProfUpdateWrapper"),
+    cl::init(false));
+
 //===----------------------------------------------------------------------===//
 //===----------------------------------------------------------------------===//
 //                            AllocaInst Class
 //                            AllocaInst Class
 //===----------------------------------------------------------------------===//
 //===----------------------------------------------------------------------===//
@@ -3880,7 +3886,7 @@ SwitchInstProfUpdateWrapper::getProfBranchWeightsMD(const SwitchInst &SI) {
 }
 }
 
 
 MDNode *SwitchInstProfUpdateWrapper::buildProfBranchWeightsMD() {
 MDNode *SwitchInstProfUpdateWrapper::buildProfBranchWeightsMD() {
-  assert(Changed && "called only if metadata has changed");
+  assert(State == Changed && "called only if metadata has changed");
 
 
   if (!Weights)
   if (!Weights)
     return nullptr;
     return nullptr;
@@ -3897,11 +3903,20 @@ MDNode *SwitchInstProfUpdateWrapper::buildProfBranchWeightsMD() {
   return MDBuilder(SI.getParent()->getContext()).createBranchWeights(*Weights);
   return MDBuilder(SI.getParent()->getContext()).createBranchWeights(*Weights);
 }
 }
 
 
-Optional<SmallVector<uint32_t, 8> >
-SwitchInstProfUpdateWrapper::getProfBranchWeights() {
+void SwitchInstProfUpdateWrapper::init() {
   MDNode *ProfileData = getProfBranchWeightsMD(SI);
   MDNode *ProfileData = getProfBranchWeightsMD(SI);
-  if (!ProfileData)
-    return None;
+  if (!ProfileData) {
+    State = Initialized;
+    return;
+  }
+
+  if (ProfileData->getNumOperands() != SI.getNumSuccessors() + 1) {
+    State = Invalid;
+    if (SwitchInstProfUpdateWrapperStrict)
+      assert(!"number of prof branch_weights metadata operands corresponds to"
+              " number of succesors");
+    return;
+  }
 
 
   SmallVector<uint32_t, 8> Weights;
   SmallVector<uint32_t, 8> Weights;
   for (unsigned CI = 1, CE = SI.getNumSuccessors(); CI <= CE; ++CI) {
   for (unsigned CI = 1, CE = SI.getNumSuccessors(); CI <= CE; ++CI) {
@@ -3909,7 +3924,8 @@ SwitchInstProfUpdateWrapper::getProfBranchWeights() {
     uint32_t CW = C->getValue().getZExtValue();
     uint32_t CW = C->getValue().getZExtValue();
     Weights.push_back(CW);
     Weights.push_back(CW);
   }
   }
-  return Weights;
+  State = Initialized;
+  this->Weights = std::move(Weights);
 }
 }
 
 
 SwitchInst::CaseIt
 SwitchInst::CaseIt
@@ -3917,7 +3933,7 @@ SwitchInstProfUpdateWrapper::removeCase(SwitchInst::CaseIt I) {
   if (Weights) {
   if (Weights) {
     assert(SI.getNumSuccessors() == Weights->size() &&
     assert(SI.getNumSuccessors() == Weights->size() &&
            "num of prof branch_weights must accord with num of successors");
            "num of prof branch_weights must accord with num of successors");
-    Changed = true;
+    State = Changed;
     // Copy the last case to the place of the removed one and shrink.
     // Copy the last case to the place of the removed one and shrink.
     // This is tightly coupled with the way SwitchInst::removeCase() removes
     // This is tightly coupled with the way SwitchInst::removeCase() removes
     // the cases in SwitchInst::removeCase(CaseIt).
     // the cases in SwitchInst::removeCase(CaseIt).
@@ -3932,12 +3948,15 @@ void SwitchInstProfUpdateWrapper::addCase(
     SwitchInstProfUpdateWrapper::CaseWeightOpt W) {
     SwitchInstProfUpdateWrapper::CaseWeightOpt W) {
   SI.addCase(OnVal, Dest);
   SI.addCase(OnVal, Dest);
 
 
+  if (State == Invalid)
+    return;
+
   if (!Weights && W && *W) {
   if (!Weights && W && *W) {
-    Changed = true;
+    State = Changed;
     Weights = SmallVector<uint32_t, 8>(SI.getNumSuccessors(), 0);
     Weights = SmallVector<uint32_t, 8>(SI.getNumSuccessors(), 0);
     Weights.getValue()[SI.getNumSuccessors() - 1] = *W;
     Weights.getValue()[SI.getNumSuccessors() - 1] = *W;
   } else if (Weights) {
   } else if (Weights) {
-    Changed = true;
+    State = Changed;
     Weights.getValue().push_back(W ? *W : 0);
     Weights.getValue().push_back(W ? *W : 0);
   }
   }
   if (Weights)
   if (Weights)
@@ -3948,10 +3967,11 @@ void SwitchInstProfUpdateWrapper::addCase(
 SymbolTableList<Instruction>::iterator
 SymbolTableList<Instruction>::iterator
 SwitchInstProfUpdateWrapper::eraseFromParent() {
 SwitchInstProfUpdateWrapper::eraseFromParent() {
   // Instruction is erased. Mark as unchanged to not touch it in the destructor.
   // Instruction is erased. Mark as unchanged to not touch it in the destructor.
-  Changed = false;
-
-  if (Weights)
-    Weights->resize(0);
+  if (State != Invalid) {
+    State = Initialized;
+    if (Weights)
+      Weights->resize(0);
+  }
   return SI.eraseFromParent();
   return SI.eraseFromParent();
 }
 }
 
 
@@ -3964,7 +3984,7 @@ SwitchInstProfUpdateWrapper::getSuccessorWeight(unsigned idx) {
 
 
 void SwitchInstProfUpdateWrapper::setSuccessorWeight(
 void SwitchInstProfUpdateWrapper::setSuccessorWeight(
     unsigned idx, SwitchInstProfUpdateWrapper::CaseWeightOpt W) {
     unsigned idx, SwitchInstProfUpdateWrapper::CaseWeightOpt W) {
-  if (!W)
+  if (!W || State == Invalid)
     return;
     return;
 
 
   if (!Weights && *W)
   if (!Weights && *W)
@@ -3973,7 +3993,7 @@ void SwitchInstProfUpdateWrapper::setSuccessorWeight(
   if (Weights) {
   if (Weights) {
     auto &OldW = Weights.getValue()[idx];
     auto &OldW = Weights.getValue()[idx];
     if (*W != OldW) {
     if (*W != OldW) {
-      Changed = true;
+      State = Changed;
       OldW = *W;
       OldW = *W;
     }
     }
   }
   }
@@ -3983,9 +4003,10 @@ SwitchInstProfUpdateWrapper::CaseWeightOpt
 SwitchInstProfUpdateWrapper::getSuccessorWeight(const SwitchInst &SI,
 SwitchInstProfUpdateWrapper::getSuccessorWeight(const SwitchInst &SI,
                                                 unsigned idx) {
                                                 unsigned idx) {
   if (MDNode *ProfileData = getProfBranchWeightsMD(SI))
   if (MDNode *ProfileData = getProfBranchWeightsMD(SI))
-    return mdconst::extract<ConstantInt>(ProfileData->getOperand(idx + 1))
-        ->getValue()
-        .getZExtValue();
+    if (ProfileData->getNumOperands() == SI.getNumSuccessors() + 1)
+      return mdconst::extract<ConstantInt>(ProfileData->getOperand(idx + 1))
+          ->getValue()
+          .getZExtValue();
 
 
   return None;
   return None;
 }
 }

+ 79 - 0
unittests/IR/InstructionsTest.cpp

@@ -753,6 +753,85 @@ TEST(InstructionsTest, SwitchInst) {
   EXPECT_EQ(BB1.get(), Handle.getCaseSuccessor());
   EXPECT_EQ(BB1.get(), Handle.getCaseSuccessor());
 }
 }
 
 
+TEST(InstructionsTest, SwitchInstProfUpdateWrapper) {
+  LLVMContext C;
+
+  std::unique_ptr<BasicBlock> BB1, BB2, BB3;
+  BB1.reset(BasicBlock::Create(C));
+  BB2.reset(BasicBlock::Create(C));
+  BB3.reset(BasicBlock::Create(C));
+
+  // We create block 0 after the others so that it gets destroyed first and
+  // clears the uses of the other basic blocks.
+  std::unique_ptr<BasicBlock> BB0(BasicBlock::Create(C));
+
+  auto *Int32Ty = Type::getInt32Ty(C);
+
+  SwitchInst *SI =
+      SwitchInst::Create(UndefValue::get(Int32Ty), BB0.get(), 4, BB0.get());
+  SI->addCase(ConstantInt::get(Int32Ty, 1), BB1.get());
+  SI->addCase(ConstantInt::get(Int32Ty, 2), BB2.get());
+  SI->setMetadata(LLVMContext::MD_prof,
+                  MDBuilder(C).createBranchWeights({ 9, 1, 22 }));
+
+  {
+    SwitchInstProfUpdateWrapper SIW(*SI);
+    EXPECT_EQ(*SIW.getSuccessorWeight(0), 9u);
+    EXPECT_EQ(*SIW.getSuccessorWeight(1), 1u);
+    EXPECT_EQ(*SIW.getSuccessorWeight(2), 22u);
+    SIW.setSuccessorWeight(0, 99u);
+    SIW.setSuccessorWeight(1, 11u);
+    EXPECT_EQ(*SIW.getSuccessorWeight(0), 99u);
+    EXPECT_EQ(*SIW.getSuccessorWeight(1), 11u);
+    EXPECT_EQ(*SIW.getSuccessorWeight(2), 22u);
+  }
+
+  { // Create another wrapper and check that the data persist.
+    SwitchInstProfUpdateWrapper SIW(*SI);
+    EXPECT_EQ(*SIW.getSuccessorWeight(0), 99u);
+    EXPECT_EQ(*SIW.getSuccessorWeight(1), 11u);
+    EXPECT_EQ(*SIW.getSuccessorWeight(2), 22u);
+  }
+
+  // Make prof data invalid by adding one extra weight.
+  SI->setMetadata(LLVMContext::MD_prof, MDBuilder(C).createBranchWeights(
+                                            { 99, 11, 22, 33 })); // extra
+  { // Invalid prof data makes wrapper act as if there were no prof data.
+    SwitchInstProfUpdateWrapper SIW(*SI);
+    ASSERT_FALSE(SIW.getSuccessorWeight(0).hasValue());
+    ASSERT_FALSE(SIW.getSuccessorWeight(1).hasValue());
+    ASSERT_FALSE(SIW.getSuccessorWeight(2).hasValue());
+    SIW.addCase(ConstantInt::get(Int32Ty, 3), BB3.get(), 39);
+    ASSERT_FALSE(SIW.getSuccessorWeight(3).hasValue()); // did not add weight 39
+  }
+
+  { // With added 3rd case the prof data become consistent with num of cases.
+    SwitchInstProfUpdateWrapper SIW(*SI);
+    EXPECT_EQ(*SIW.getSuccessorWeight(0), 99u);
+    EXPECT_EQ(*SIW.getSuccessorWeight(1), 11u);
+    EXPECT_EQ(*SIW.getSuccessorWeight(2), 22u);
+    EXPECT_EQ(*SIW.getSuccessorWeight(3), 33u);
+  }
+
+  // Make prof data invalid by removing one extra weight.
+  SI->setMetadata(LLVMContext::MD_prof,
+                  MDBuilder(C).createBranchWeights({ 99, 11, 22 })); // shorter
+  { // Invalid prof data makes wrapper act as if there were no prof data.
+    SwitchInstProfUpdateWrapper SIW(*SI);
+    ASSERT_FALSE(SIW.getSuccessorWeight(0).hasValue());
+    ASSERT_FALSE(SIW.getSuccessorWeight(1).hasValue());
+    ASSERT_FALSE(SIW.getSuccessorWeight(2).hasValue());
+    SIW.removeCase(SwitchInst::CaseIt(SI, 2));
+  }
+
+  { // With removed 3rd case the prof data become consistent with num of cases.
+    SwitchInstProfUpdateWrapper SIW(*SI);
+    EXPECT_EQ(*SIW.getSuccessorWeight(0), 99u);
+    EXPECT_EQ(*SIW.getSuccessorWeight(1), 11u);
+    EXPECT_EQ(*SIW.getSuccessorWeight(2), 22u);
+  }
+}
+
 TEST(InstructionsTest, CommuteShuffleMask) {
 TEST(InstructionsTest, CommuteShuffleMask) {
   SmallVector<int, 16> Indices({-1, 0, 7});
   SmallVector<int, 16> Indices({-1, 0, 7});
   ShuffleVectorInst::commuteShuffleMask(Indices, 4);
   ShuffleVectorInst::commuteShuffleMask(Indices, 4);