1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192 |
- //===- unittests/AST/ASTPrint.h ------------------------------------------===//
- //
- // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
- // See https://llvm.org/LICENSE.txt for license information.
- // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
- //
- //===----------------------------------------------------------------------===//
- //
- // Helpers to simplify testing of printing of AST constructs provided in the
- // form of source code.
- //
- //===----------------------------------------------------------------------===//
- #include "clang/AST/ASTContext.h"
- #include "clang/ASTMatchers/ASTMatchFinder.h"
- #include "clang/Tooling/Tooling.h"
- #include "llvm/ADT/SmallString.h"
- #include "gtest/gtest.h"
- namespace clang {
- using PolicyAdjusterType =
- Optional<llvm::function_ref<void(PrintingPolicy &Policy)>>;
- static void PrintStmt(raw_ostream &Out, const ASTContext *Context,
- const Stmt *S, PolicyAdjusterType PolicyAdjuster) {
- assert(S != nullptr && "Expected non-null Stmt");
- PrintingPolicy Policy = Context->getPrintingPolicy();
- if (PolicyAdjuster)
- (*PolicyAdjuster)(Policy);
- S->printPretty(Out, /*Helper*/ nullptr, Policy);
- }
- class PrintMatch : public ast_matchers::MatchFinder::MatchCallback {
- SmallString<1024> Printed;
- unsigned NumFoundStmts;
- PolicyAdjusterType PolicyAdjuster;
- public:
- PrintMatch(PolicyAdjusterType PolicyAdjuster)
- : NumFoundStmts(0), PolicyAdjuster(PolicyAdjuster) {}
- void run(const ast_matchers::MatchFinder::MatchResult &Result) override {
- const Stmt *S = Result.Nodes.getNodeAs<Stmt>("id");
- if (!S)
- return;
- NumFoundStmts++;
- if (NumFoundStmts > 1)
- return;
- llvm::raw_svector_ostream Out(Printed);
- PrintStmt(Out, Result.Context, S, PolicyAdjuster);
- }
- StringRef getPrinted() const { return Printed; }
- unsigned getNumFoundStmts() const { return NumFoundStmts; }
- };
- template <typename T>
- ::testing::AssertionResult
- PrintedStmtMatches(StringRef Code, const std::vector<std::string> &Args,
- const T &NodeMatch, StringRef ExpectedPrinted,
- PolicyAdjusterType PolicyAdjuster = None) {
- PrintMatch Printer(PolicyAdjuster);
- ast_matchers::MatchFinder Finder;
- Finder.addMatcher(NodeMatch, &Printer);
- std::unique_ptr<tooling::FrontendActionFactory> Factory(
- tooling::newFrontendActionFactory(&Finder));
- if (!tooling::runToolOnCodeWithArgs(Factory->create(), Code, Args))
- return testing::AssertionFailure()
- << "Parsing error in \"" << Code.str() << "\"";
- if (Printer.getNumFoundStmts() == 0)
- return testing::AssertionFailure() << "Matcher didn't find any statements";
- if (Printer.getNumFoundStmts() > 1)
- return testing::AssertionFailure()
- << "Matcher should match only one statement (found "
- << Printer.getNumFoundStmts() << ")";
- if (Printer.getPrinted() != ExpectedPrinted)
- return ::testing::AssertionFailure()
- << "Expected \"" << ExpectedPrinted.str() << "\", got \""
- << Printer.getPrinted().str() << "\"";
- return ::testing::AssertionSuccess();
- }
- } // namespace clang
|