VPlanPredicatorTest.cpp 9.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229
  1. //===- llvm/unittests/Transforms/Vectorize/VPlanPredicatorTest.cpp -----===//
  2. //
  3. // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
  4. // See https://llvm.org/LICENSE.txt for license information.
  5. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
  6. //
  7. //===----------------------------------------------------------------------===//
  8. #include "../lib/Transforms/Vectorize/VPlanPredicator.h"
  9. #include "VPlanTestBase.h"
  10. #include "gtest/gtest.h"
  11. namespace llvm {
  12. namespace {
  13. class VPlanPredicatorTest : public VPlanTestBase {};
  14. TEST_F(VPlanPredicatorTest, BasicPredicatorTest) {
  15. const char *ModuleString =
  16. "@arr = common global [8 x [8 x i64]] "
  17. "zeroinitializer, align 16\n"
  18. "@arr2 = common global [8 x [8 x i64]] "
  19. "zeroinitializer, align 16\n"
  20. "@arr3 = common global [8 x [8 x i64]] "
  21. "zeroinitializer, align 16\n"
  22. "define void @f(i64 %n1) {\n"
  23. "entry:\n"
  24. " br label %for.cond1.preheader\n"
  25. "for.cond1.preheader: \n"
  26. " %i1.029 = phi i64 [ 0, %entry ], [ %inc14, %for.inc13 ]\n"
  27. " br label %for.body3\n"
  28. "for.body3: \n"
  29. " %i2.028 = phi i64 [ 0, %for.cond1.preheader ], [ %inc, %for.inc ]\n"
  30. " %arrayidx4 = getelementptr inbounds [8 x [8 x i64]], [8 x [8 x i64]]* "
  31. "@arr, i64 0, i64 %i2.028, i64 %i1.029\n"
  32. " %0 = load i64, i64* %arrayidx4, align 8\n"
  33. " %cmp5 = icmp ugt i64 %0, 10\n"
  34. " br i1 %cmp5, label %if.then, label %for.inc\n"
  35. "if.then: \n"
  36. " %arrayidx7 = getelementptr inbounds [8 x [8 x i64]], [8 x [8 x i64]]* "
  37. "@arr2, i64 0, i64 %i2.028, i64 %i1.029\n"
  38. " %1 = load i64, i64* %arrayidx7, align 8\n"
  39. " %cmp8 = icmp ugt i64 %1, 100\n"
  40. " br i1 %cmp8, label %if.then9, label %for.inc\n"
  41. "if.then9: \n"
  42. " %add = add nuw nsw i64 %i2.028, %i1.029\n"
  43. " %arrayidx11 = getelementptr inbounds [8 x [8 x i64]], [8 x [8 x "
  44. "i64]]* @arr3, i64 0, i64 %i2.028, i64 %i1.029\n"
  45. " store i64 %add, i64* %arrayidx11, align 8\n"
  46. " br label %for.inc\n"
  47. "for.inc: \n"
  48. " %inc = add nuw nsw i64 %i2.028, 1\n"
  49. " %exitcond = icmp eq i64 %inc, 8\n"
  50. " br i1 %exitcond, label %for.inc13, label %for.body3\n"
  51. "for.inc13: \n"
  52. " %inc14 = add nuw nsw i64 %i1.029, 1\n"
  53. " %exitcond30 = icmp eq i64 %inc14, 8\n"
  54. " br i1 %exitcond30, label %for.end15, label %for.cond1.preheader\n"
  55. "for.end15: \n"
  56. " ret void\n"
  57. "}\n";
  58. Module &M = parseModule(ModuleString);
  59. Function *F = M.getFunction("f");
  60. BasicBlock *LoopHeader = F->getEntryBlock().getSingleSuccessor();
  61. auto Plan = buildHCFG(LoopHeader);
  62. VPRegionBlock *TopRegion = cast<VPRegionBlock>(Plan->getEntry());
  63. VPBlockBase *PH = TopRegion->getEntry();
  64. VPBlockBase *H = PH->getSingleSuccessor();
  65. VPBlockBase *InnerLoopH = H->getSingleSuccessor();
  66. VPBlockBase *OuterIf = InnerLoopH->getSuccessors()[0];
  67. VPBlockBase *InnerLoopLatch = InnerLoopH->getSuccessors()[1];
  68. VPBlockBase *InnerIf = OuterIf->getSuccessors()[0];
  69. VPValue *CBV1 = InnerLoopH->getCondBit();
  70. VPValue *CBV2 = OuterIf->getCondBit();
  71. // Apply predication.
  72. VPlanPredicator VPP(*Plan);
  73. VPP.predicate();
  74. VPBlockBase *InnerLoopLinSucc = InnerLoopH->getSingleSuccessor();
  75. VPBlockBase *OuterIfLinSucc = OuterIf->getSingleSuccessor();
  76. VPBlockBase *InnerIfLinSucc = InnerIf->getSingleSuccessor();
  77. VPValue *OuterIfPred = OuterIf->getPredicate();
  78. VPInstruction *InnerAnd =
  79. cast<VPInstruction>(InnerIf->getEntryBasicBlock()->begin());
  80. VPValue *InnerIfPred = InnerIf->getPredicate();
  81. // Test block predicates
  82. EXPECT_NE(nullptr, CBV1);
  83. EXPECT_NE(nullptr, CBV2);
  84. EXPECT_NE(nullptr, InnerAnd);
  85. EXPECT_EQ(CBV1, OuterIfPred);
  86. EXPECT_EQ(InnerAnd->getOpcode(), Instruction::And);
  87. EXPECT_EQ(InnerAnd->getOperand(0), CBV1);
  88. EXPECT_EQ(InnerAnd->getOperand(1), CBV2);
  89. EXPECT_EQ(InnerIfPred, InnerAnd);
  90. // Test Linearization
  91. EXPECT_EQ(InnerLoopLinSucc, OuterIf);
  92. EXPECT_EQ(OuterIfLinSucc, InnerIf);
  93. EXPECT_EQ(InnerIfLinSucc, InnerLoopLatch);
  94. }
  95. // Test generation of Not and Or during predication.
  96. TEST_F(VPlanPredicatorTest, PredicatorNegOrTest) {
  97. const char *ModuleString =
  98. "@arr = common global [100 x [100 x i32]] zeroinitializer, align 16\n"
  99. "@arr2 = common global [100 x [100 x i32]] zeroinitializer, align 16\n"
  100. "@arr3 = common global [100 x [100 x i32]] zeroinitializer, align 16\n"
  101. "define void @foo() {\n"
  102. "entry:\n"
  103. " br label %for.cond1.preheader\n"
  104. "for.cond1.preheader: \n"
  105. " %indvars.iv42 = phi i64 [ 0, %entry ], [ %indvars.iv.next43, "
  106. "%for.inc22 ]\n"
  107. " br label %for.body3\n"
  108. "for.body3: \n"
  109. " %indvars.iv = phi i64 [ 0, %for.cond1.preheader ], [ "
  110. "%indvars.iv.next, %if.end21 ]\n"
  111. " %arrayidx5 = getelementptr inbounds [100 x [100 x i32]], [100 x [100 "
  112. "x i32]]* @arr, i64 0, i64 %indvars.iv, i64 %indvars.iv42\n"
  113. " %0 = load i32, i32* %arrayidx5, align 4\n"
  114. " %cmp6 = icmp slt i32 %0, 100\n"
  115. " br i1 %cmp6, label %if.then, label %if.end21\n"
  116. "if.then: \n"
  117. " %cmp7 = icmp sgt i32 %0, 10\n"
  118. " br i1 %cmp7, label %if.then8, label %if.else\n"
  119. "if.then8: \n"
  120. " %add = add nsw i32 %0, 10\n"
  121. " %arrayidx12 = getelementptr inbounds [100 x [100 x i32]], [100 x [100 "
  122. "x i32]]* @arr2, i64 0, i64 %indvars.iv, i64 %indvars.iv42\n"
  123. " store i32 %add, i32* %arrayidx12, align 4\n"
  124. " br label %if.end\n"
  125. "if.else: \n"
  126. " %sub = add nsw i32 %0, -10\n"
  127. " %arrayidx16 = getelementptr inbounds [100 x [100 x i32]], [100 x [100 "
  128. "x i32]]* @arr3, i64 0, i64 %indvars.iv, i64 %indvars.iv42\n"
  129. " store i32 %sub, i32* %arrayidx16, align 4\n"
  130. " br label %if.end\n"
  131. "if.end: \n"
  132. " store i32 222, i32* %arrayidx5, align 4\n"
  133. " br label %if.end21\n"
  134. "if.end21: \n"
  135. " %indvars.iv.next = add nuw nsw i64 %indvars.iv, 1\n"
  136. " %exitcond = icmp eq i64 %indvars.iv.next, 100\n"
  137. " br i1 %exitcond, label %for.inc22, label %for.body3\n"
  138. "for.inc22: \n"
  139. " %indvars.iv.next43 = add nuw nsw i64 %indvars.iv42, 1\n"
  140. " %exitcond44 = icmp eq i64 %indvars.iv.next43, 100\n"
  141. " br i1 %exitcond44, label %for.end24, label %for.cond1.preheader\n"
  142. "for.end24: \n"
  143. " ret void\n"
  144. "}\n";
  145. Module &M = parseModule(ModuleString);
  146. Function *F = M.getFunction("foo");
  147. BasicBlock *LoopHeader = F->getEntryBlock().getSingleSuccessor();
  148. auto Plan = buildHCFG(LoopHeader);
  149. VPRegionBlock *TopRegion = cast<VPRegionBlock>(Plan->getEntry());
  150. VPBlockBase *PH = TopRegion->getEntry();
  151. VPBlockBase *H = PH->getSingleSuccessor();
  152. VPBlockBase *OuterIfCmpBlk = H->getSingleSuccessor();
  153. VPBlockBase *InnerIfCmpBlk = OuterIfCmpBlk->getSuccessors()[0];
  154. VPBlockBase *InnerIfTSucc = InnerIfCmpBlk->getSuccessors()[0];
  155. VPBlockBase *InnerIfFSucc = InnerIfCmpBlk->getSuccessors()[1];
  156. VPBlockBase *TSuccSucc = InnerIfTSucc->getSingleSuccessor();
  157. VPBlockBase *FSuccSucc = InnerIfFSucc->getSingleSuccessor();
  158. VPValue *OuterCBV = OuterIfCmpBlk->getCondBit();
  159. VPValue *InnerCBV = InnerIfCmpBlk->getCondBit();
  160. // Apply predication.
  161. VPlanPredicator VPP(*Plan);
  162. VPP.predicate();
  163. VPInstruction *And =
  164. cast<VPInstruction>(InnerIfTSucc->getEntryBasicBlock()->begin());
  165. VPInstruction *Not =
  166. cast<VPInstruction>(InnerIfFSucc->getEntryBasicBlock()->begin());
  167. VPInstruction *NotAnd = cast<VPInstruction>(
  168. &*std::next(InnerIfFSucc->getEntryBasicBlock()->begin(), 1));
  169. VPInstruction *Or =
  170. cast<VPInstruction>(TSuccSucc->getEntryBasicBlock()->begin());
  171. // Test block predicates
  172. EXPECT_NE(nullptr, OuterCBV);
  173. EXPECT_NE(nullptr, InnerCBV);
  174. EXPECT_NE(nullptr, And);
  175. EXPECT_NE(nullptr, Not);
  176. EXPECT_NE(nullptr, NotAnd);
  177. EXPECT_EQ(And->getOpcode(), Instruction::And);
  178. EXPECT_EQ(NotAnd->getOpcode(), Instruction::And);
  179. EXPECT_EQ(Not->getOpcode(), VPInstruction::Not);
  180. EXPECT_EQ(And->getOperand(0), OuterCBV);
  181. EXPECT_EQ(And->getOperand(1), InnerCBV);
  182. EXPECT_EQ(Not->getOperand(0), InnerCBV);
  183. EXPECT_EQ(NotAnd->getOperand(0), OuterCBV);
  184. EXPECT_EQ(NotAnd->getOperand(1), Not);
  185. EXPECT_EQ(InnerIfTSucc->getPredicate(), And);
  186. EXPECT_EQ(InnerIfFSucc->getPredicate(), NotAnd);
  187. EXPECT_EQ(TSuccSucc, FSuccSucc);
  188. EXPECT_EQ(Or->getOpcode(), Instruction::Or);
  189. EXPECT_EQ(TSuccSucc->getPredicate(), Or);
  190. // Test operands of the Or - account for differences in predecessor block
  191. // ordering.
  192. VPInstruction *OrOp0Inst = cast<VPInstruction>(Or->getOperand(0));
  193. VPInstruction *OrOp1Inst = cast<VPInstruction>(Or->getOperand(1));
  194. bool ValidOrOperands = false;
  195. if (((OrOp0Inst == And) && (OrOp1Inst == NotAnd)) ||
  196. ((OrOp0Inst == NotAnd) && (OrOp1Inst == And)))
  197. ValidOrOperands = true;
  198. EXPECT_TRUE(ValidOrOperands);
  199. }
  200. } // namespace
  201. } // namespace llvm