ASTPrint.h 3.1 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192
  1. //===- unittests/AST/ASTPrint.h ------------------------------------------===//
  2. //
  3. // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
  4. // See https://llvm.org/LICENSE.txt for license information.
  5. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
  6. //
  7. //===----------------------------------------------------------------------===//
  8. //
  9. // Helpers to simplify testing of printing of AST constructs provided in the
  10. // form of the source code.
  11. //
  12. //===----------------------------------------------------------------------===//
  13. #include "clang/AST/ASTContext.h"
  14. #include "clang/ASTMatchers/ASTMatchFinder.h"
  15. #include "clang/Tooling/Tooling.h"
  16. #include "llvm/ADT/SmallString.h"
  17. #include "gtest/gtest.h"
  18. namespace clang {
  19. using PolicyAdjusterType =
  20. Optional<llvm::function_ref<void(PrintingPolicy &Policy)>>;
  21. static void PrintStmt(raw_ostream &Out, const ASTContext *Context,
  22. const Stmt *S, PolicyAdjusterType PolicyAdjuster) {
  23. assert(S != nullptr && "Expected non-null Stmt");
  24. PrintingPolicy Policy = Context->getPrintingPolicy();
  25. if (PolicyAdjuster)
  26. (*PolicyAdjuster)(Policy);
  27. S->printPretty(Out, /*Helper*/ nullptr, Policy);
  28. }
  29. class PrintMatch : public ast_matchers::MatchFinder::MatchCallback {
  30. SmallString<1024> Printed;
  31. unsigned NumFoundStmts;
  32. PolicyAdjusterType PolicyAdjuster;
  33. public:
  34. PrintMatch(PolicyAdjusterType PolicyAdjuster)
  35. : NumFoundStmts(0), PolicyAdjuster(PolicyAdjuster) {}
  36. void run(const ast_matchers::MatchFinder::MatchResult &Result) override {
  37. const Stmt *S = Result.Nodes.getNodeAs<Stmt>("id");
  38. if (!S)
  39. return;
  40. NumFoundStmts++;
  41. if (NumFoundStmts > 1)
  42. return;
  43. llvm::raw_svector_ostream Out(Printed);
  44. PrintStmt(Out, Result.Context, S, PolicyAdjuster);
  45. }
  46. StringRef getPrinted() const { return Printed; }
  47. unsigned getNumFoundStmts() const { return NumFoundStmts; }
  48. };
  49. template <typename T>
  50. ::testing::AssertionResult
  51. PrintedStmtMatches(StringRef Code, const std::vector<std::string> &Args,
  52. const T &NodeMatch, StringRef ExpectedPrinted,
  53. PolicyAdjusterType PolicyAdjuster = None) {
  54. PrintMatch Printer(PolicyAdjuster);
  55. ast_matchers::MatchFinder Finder;
  56. Finder.addMatcher(NodeMatch, &Printer);
  57. std::unique_ptr<tooling::FrontendActionFactory> Factory(
  58. tooling::newFrontendActionFactory(&Finder));
  59. if (!tooling::runToolOnCodeWithArgs(Factory->create(), Code, Args))
  60. return testing::AssertionFailure()
  61. << "Parsing error in \"" << Code.str() << "\"";
  62. if (Printer.getNumFoundStmts() == 0)
  63. return testing::AssertionFailure() << "Matcher didn't find any statements";
  64. if (Printer.getNumFoundStmts() > 1)
  65. return testing::AssertionFailure()
  66. << "Matcher should match only one statement (found "
  67. << Printer.getNumFoundStmts() << ")";
  68. if (Printer.getPrinted() != ExpectedPrinted)
  69. return ::testing::AssertionFailure()
  70. << "Expected \"" << ExpectedPrinted.str() << "\", got \""
  71. << Printer.getPrinted().str() << "\"";
  72. return ::testing::AssertionSuccess();
  73. }
  74. } // namespace clang