RefactoringCallbacks.cpp 8.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241
  1. //===--- RefactoringCallbacks.cpp - Structural query framework ------------===//
  2. //
  3. // The LLVM Compiler Infrastructure
  4. //
  5. // This file is distributed under the University of Illinois Open Source
  6. // License. See LICENSE.TXT for details.
  7. //
  8. //===----------------------------------------------------------------------===//
  9. //
  10. //
  11. //===----------------------------------------------------------------------===//
  12. #include "clang/Tooling/RefactoringCallbacks.h"
  13. #include "clang/ASTMatchers/ASTMatchFinder.h"
  14. #include "clang/Basic/SourceLocation.h"
  15. #include "clang/Lex/Lexer.h"
  16. using llvm::StringError;
  17. using llvm::make_error;
  18. namespace clang {
  19. namespace tooling {
  20. RefactoringCallback::RefactoringCallback() {}
  21. tooling::Replacements &RefactoringCallback::getReplacements() {
  22. return Replace;
  23. }
  24. ASTMatchRefactorer::ASTMatchRefactorer(
  25. std::map<std::string, Replacements> &FileToReplaces)
  26. : FileToReplaces(FileToReplaces) {}
  27. void ASTMatchRefactorer::addDynamicMatcher(
  28. const ast_matchers::internal::DynTypedMatcher &Matcher,
  29. RefactoringCallback *Callback) {
  30. MatchFinder.addDynamicMatcher(Matcher, Callback);
  31. Callbacks.push_back(Callback);
  32. }
  33. class RefactoringASTConsumer : public ASTConsumer {
  34. public:
  35. RefactoringASTConsumer(ASTMatchRefactorer &Refactoring)
  36. : Refactoring(Refactoring) {}
  37. void HandleTranslationUnit(ASTContext &Context) override {
  38. // The ASTMatchRefactorer is re-used between translation units.
  39. // Clear the matchers so that each Replacement is only emitted once.
  40. for (const auto &Callback : Refactoring.Callbacks) {
  41. Callback->getReplacements().clear();
  42. }
  43. Refactoring.MatchFinder.matchAST(Context);
  44. for (const auto &Callback : Refactoring.Callbacks) {
  45. for (const auto &Replacement : Callback->getReplacements()) {
  46. llvm::Error Err =
  47. Refactoring.FileToReplaces[Replacement.getFilePath()].add(
  48. Replacement);
  49. if (Err) {
  50. llvm::errs() << "Skipping replacement " << Replacement.toString()
  51. << " due to this error:\n"
  52. << toString(std::move(Err)) << "\n";
  53. }
  54. }
  55. }
  56. }
  57. private:
  58. ASTMatchRefactorer &Refactoring;
  59. };
  60. std::unique_ptr<ASTConsumer> ASTMatchRefactorer::newASTConsumer() {
  61. return llvm::make_unique<RefactoringASTConsumer>(*this);
  62. }
  63. static Replacement replaceStmtWithText(SourceManager &Sources, const Stmt &From,
  64. StringRef Text) {
  65. return tooling::Replacement(
  66. Sources, CharSourceRange::getTokenRange(From.getSourceRange()), Text);
  67. }
  68. static Replacement replaceStmtWithStmt(SourceManager &Sources, const Stmt &From,
  69. const Stmt &To) {
  70. return replaceStmtWithText(
  71. Sources, From,
  72. Lexer::getSourceText(CharSourceRange::getTokenRange(To.getSourceRange()),
  73. Sources, LangOptions()));
  74. }
  75. ReplaceStmtWithText::ReplaceStmtWithText(StringRef FromId, StringRef ToText)
  76. : FromId(FromId), ToText(ToText) {}
  77. void ReplaceStmtWithText::run(
  78. const ast_matchers::MatchFinder::MatchResult &Result) {
  79. if (const Stmt *FromMatch = Result.Nodes.getNodeAs<Stmt>(FromId)) {
  80. auto Err = Replace.add(tooling::Replacement(
  81. *Result.SourceManager,
  82. CharSourceRange::getTokenRange(FromMatch->getSourceRange()), ToText));
  83. // FIXME: better error handling. For now, just print error message in the
  84. // release version.
  85. if (Err) {
  86. llvm::errs() << llvm::toString(std::move(Err)) << "\n";
  87. assert(false);
  88. }
  89. }
  90. }
  91. ReplaceStmtWithStmt::ReplaceStmtWithStmt(StringRef FromId, StringRef ToId)
  92. : FromId(FromId), ToId(ToId) {}
  93. void ReplaceStmtWithStmt::run(
  94. const ast_matchers::MatchFinder::MatchResult &Result) {
  95. const Stmt *FromMatch = Result.Nodes.getNodeAs<Stmt>(FromId);
  96. const Stmt *ToMatch = Result.Nodes.getNodeAs<Stmt>(ToId);
  97. if (FromMatch && ToMatch) {
  98. auto Err = Replace.add(
  99. replaceStmtWithStmt(*Result.SourceManager, *FromMatch, *ToMatch));
  100. // FIXME: better error handling. For now, just print error message in the
  101. // release version.
  102. if (Err) {
  103. llvm::errs() << llvm::toString(std::move(Err)) << "\n";
  104. assert(false);
  105. }
  106. }
  107. }
  108. ReplaceIfStmtWithItsBody::ReplaceIfStmtWithItsBody(StringRef Id,
  109. bool PickTrueBranch)
  110. : Id(Id), PickTrueBranch(PickTrueBranch) {}
  111. void ReplaceIfStmtWithItsBody::run(
  112. const ast_matchers::MatchFinder::MatchResult &Result) {
  113. if (const IfStmt *Node = Result.Nodes.getNodeAs<IfStmt>(Id)) {
  114. const Stmt *Body = PickTrueBranch ? Node->getThen() : Node->getElse();
  115. if (Body) {
  116. auto Err =
  117. Replace.add(replaceStmtWithStmt(*Result.SourceManager, *Node, *Body));
  118. // FIXME: better error handling. For now, just print error message in the
  119. // release version.
  120. if (Err) {
  121. llvm::errs() << llvm::toString(std::move(Err)) << "\n";
  122. assert(false);
  123. }
  124. } else if (!PickTrueBranch) {
  125. // If we want to use the 'else'-branch, but it doesn't exist, delete
  126. // the whole 'if'.
  127. auto Err =
  128. Replace.add(replaceStmtWithText(*Result.SourceManager, *Node, ""));
  129. // FIXME: better error handling. For now, just print error message in the
  130. // release version.
  131. if (Err) {
  132. llvm::errs() << llvm::toString(std::move(Err)) << "\n";
  133. assert(false);
  134. }
  135. }
  136. }
  137. }
  138. ReplaceNodeWithTemplate::ReplaceNodeWithTemplate(
  139. llvm::StringRef FromId, std::vector<TemplateElement> &&Template)
  140. : FromId(FromId), Template(Template) {}
  141. llvm::Expected<std::unique_ptr<ReplaceNodeWithTemplate>>
  142. ReplaceNodeWithTemplate::create(StringRef FromId, StringRef ToTemplate) {
  143. std::vector<TemplateElement> ParsedTemplate;
  144. for (size_t Index = 0; Index < ToTemplate.size();) {
  145. if (ToTemplate[Index] == '$') {
  146. if (ToTemplate.substr(Index, 2) == "$$") {
  147. Index += 2;
  148. ParsedTemplate.push_back(
  149. TemplateElement{TemplateElement::Literal, "$"});
  150. } else if (ToTemplate.substr(Index, 2) == "${") {
  151. size_t EndOfIdentifier = ToTemplate.find("}", Index);
  152. if (EndOfIdentifier == std::string::npos) {
  153. return make_error<StringError>(
  154. "Unterminated ${...} in replacement template near " +
  155. ToTemplate.substr(Index),
  156. std::make_error_code(std::errc::bad_message));
  157. }
  158. std::string SourceNodeName =
  159. ToTemplate.substr(Index + 2, EndOfIdentifier - Index - 2);
  160. ParsedTemplate.push_back(
  161. TemplateElement{TemplateElement::Identifier, SourceNodeName});
  162. Index = EndOfIdentifier + 1;
  163. } else {
  164. return make_error<StringError>(
  165. "Invalid $ in replacement template near " +
  166. ToTemplate.substr(Index),
  167. std::make_error_code(std::errc::bad_message));
  168. }
  169. } else {
  170. size_t NextIndex = ToTemplate.find('$', Index + 1);
  171. ParsedTemplate.push_back(
  172. TemplateElement{TemplateElement::Literal,
  173. ToTemplate.substr(Index, NextIndex - Index)});
  174. Index = NextIndex;
  175. }
  176. }
  177. return std::unique_ptr<ReplaceNodeWithTemplate>(
  178. new ReplaceNodeWithTemplate(FromId, std::move(ParsedTemplate)));
  179. }
  180. void ReplaceNodeWithTemplate::run(
  181. const ast_matchers::MatchFinder::MatchResult &Result) {
  182. const auto &NodeMap = Result.Nodes.getMap();
  183. std::string ToText;
  184. for (const auto &Element : Template) {
  185. switch (Element.Type) {
  186. case TemplateElement::Literal:
  187. ToText += Element.Value;
  188. break;
  189. case TemplateElement::Identifier: {
  190. auto NodeIter = NodeMap.find(Element.Value);
  191. if (NodeIter == NodeMap.end()) {
  192. llvm::errs() << "Node " << Element.Value
  193. << " used in replacement template not bound in Matcher \n";
  194. llvm::report_fatal_error("Unbound node in replacement template.");
  195. }
  196. CharSourceRange Source =
  197. CharSourceRange::getTokenRange(NodeIter->second.getSourceRange());
  198. ToText += Lexer::getSourceText(Source, *Result.SourceManager,
  199. Result.Context->getLangOpts());
  200. break;
  201. }
  202. }
  203. }
  204. if (NodeMap.count(FromId) == 0) {
  205. llvm::errs() << "Node to be replaced " << FromId
  206. << " not bound in query.\n";
  207. llvm::report_fatal_error("FromId node not bound in MatchResult");
  208. }
  209. auto Replacement =
  210. tooling::Replacement(*Result.SourceManager, &NodeMap.at(FromId), ToText,
  211. Result.Context->getLangOpts());
  212. llvm::Error Err = Replace.add(Replacement);
  213. if (Err) {
  214. llvm::errs() << "Query and replace failed in " << Replacement.getFilePath()
  215. << "! " << llvm::toString(std::move(Err)) << "\n";
  216. llvm::report_fatal_error("Replacement failed");
  217. }
  218. }
  219. } // end namespace tooling
  220. } // end namespace clang