Pārlūkot izejas kodu

[CodeComplete] Propagate preferred types through parser in more cases

Preferred types are used by code completion for ranking. This commit
considerably increases the number of points in code where those types
are propagated.

In order to avoid complicating signatures of Parser's methods, a
preferred type is kept as a member variable in the parser and updated
during parsing.

Differential revision: https://reviews.llvm.org/D56723

git-svn-id: https://llvm.org/svn/llvm-project/cfe/trunk@352788 91177308-0d34-0410-b5e6-96231b3b80d8
Ilya Biryukov 6 gadi atpakaļ
vecāks
revīzija
d60b387098

+ 7 - 0
include/clang/Parse/Parser.h

@@ -74,6 +74,10 @@ class Parser : public CodeCompletionHandler {
   // a statement).
   // a statement).
   SourceLocation PrevTokLocation;
   SourceLocation PrevTokLocation;
 
 
+  /// Tracks an expected type for the current token when parsing an expression.
+  /// Used by code completion for ranking.
+  PreferredTypeBuilder PreferredType;
+
   unsigned short ParenCount = 0, BracketCount = 0, BraceCount = 0;
   unsigned short ParenCount = 0, BracketCount = 0, BraceCount = 0;
   unsigned short MisplacedModuleBeginCount = 0;
   unsigned short MisplacedModuleBeginCount = 0;
 
 
@@ -840,6 +844,7 @@ private:
   ///
   ///
   class TentativeParsingAction {
   class TentativeParsingAction {
     Parser &P;
     Parser &P;
+    PreferredTypeBuilder PrevPreferredType;
     Token PrevTok;
     Token PrevTok;
     size_t PrevTentativelyDeclaredIdentifierCount;
     size_t PrevTentativelyDeclaredIdentifierCount;
     unsigned short PrevParenCount, PrevBracketCount, PrevBraceCount;
     unsigned short PrevParenCount, PrevBracketCount, PrevBraceCount;
@@ -847,6 +852,7 @@ private:
 
 
   public:
   public:
     explicit TentativeParsingAction(Parser& p) : P(p) {
     explicit TentativeParsingAction(Parser& p) : P(p) {
+      PrevPreferredType = P.PreferredType;
       PrevTok = P.Tok;
       PrevTok = P.Tok;
       PrevTentativelyDeclaredIdentifierCount =
       PrevTentativelyDeclaredIdentifierCount =
           P.TentativelyDeclaredIdentifiers.size();
           P.TentativelyDeclaredIdentifiers.size();
@@ -866,6 +872,7 @@ private:
     void Revert() {
     void Revert() {
       assert(isActive && "Parsing action was finished!");
       assert(isActive && "Parsing action was finished!");
       P.PP.Backtrack();
       P.PP.Backtrack();
+      P.PreferredType = PrevPreferredType;
       P.Tok = PrevTok;
       P.Tok = PrevTok;
       P.TentativelyDeclaredIdentifiers.resize(
       P.TentativelyDeclaredIdentifiers.resize(
           PrevTentativelyDeclaredIdentifierCount);
           PrevTentativelyDeclaredIdentifierCount);

+ 1 - 0
include/clang/Sema/CodeCompleteConsumer.h

@@ -380,6 +380,7 @@ public:
   /// if the expression is a variable initializer or a function argument, the
   /// if the expression is a variable initializer or a function argument, the
   /// type of the corresponding variable or function parameter.
   /// type of the corresponding variable or function parameter.
   QualType getPreferredType() const { return PreferredType; }
   QualType getPreferredType() const { return PreferredType; }
+  void setPreferredType(QualType T) { PreferredType = T; }
 
 
   /// Retrieve the type of the base object in a member-access
   /// Retrieve the type of the base object in a member-access
   /// expression.
   /// expression.

+ 41 - 5
include/clang/Sema/Sema.h

@@ -273,6 +273,41 @@ public:
   }
   }
 };
 };
 
 
+/// Keeps track of expected type during expression parsing. The type is tied to
+/// a particular token, all functions that update or consume the type take a
+/// start location of the token they are looking at as a parameter. This allows
+/// to avoid updating the type on hot paths in the parser.
+class PreferredTypeBuilder {
+public:
+  PreferredTypeBuilder() = default;
+  explicit PreferredTypeBuilder(QualType Type) : Type(Type) {}
+
+  void enterCondition(Sema &S, SourceLocation Tok);
+  void enterReturn(Sema &S, SourceLocation Tok);
+  void enterVariableInit(SourceLocation Tok, Decl *D);
+
+  void enterParenExpr(SourceLocation Tok, SourceLocation LParLoc);
+  void enterUnary(Sema &S, SourceLocation Tok, tok::TokenKind OpKind,
+                  SourceLocation OpLoc);
+  void enterBinary(Sema &S, SourceLocation Tok, Expr *LHS, tok::TokenKind Op);
+  void enterMemAccess(Sema &S, SourceLocation Tok, Expr *Base);
+  void enterSubscript(Sema &S, SourceLocation Tok, Expr *LHS);
+  /// Handles all type casts, including C-style cast, C++ casts, etc.
+  void enterTypeCast(SourceLocation Tok, QualType CastType);
+
+  QualType get(SourceLocation Tok) const {
+    if (Tok == ExpectedLoc)
+      return Type;
+    return QualType();
+  }
+
+private:
+  /// Start position of a token for which we store expected type.
+  SourceLocation ExpectedLoc;
+  /// Expected type for a token starting at ExpectedLoc.
+  QualType Type;
+};
+
 /// Sema - This implements semantic analysis and AST building for C.
 /// Sema - This implements semantic analysis and AST building for C.
 class Sema {
 class Sema {
   Sema(const Sema &) = delete;
   Sema(const Sema &) = delete;
@@ -10371,11 +10406,14 @@ public:
   struct CodeCompleteExpressionData;
   struct CodeCompleteExpressionData;
   void CodeCompleteExpression(Scope *S,
   void CodeCompleteExpression(Scope *S,
                               const CodeCompleteExpressionData &Data);
                               const CodeCompleteExpressionData &Data);
-  void CodeCompleteExpression(Scope *S, QualType PreferredType);
+  void CodeCompleteExpression(Scope *S, QualType PreferredType,
+                              bool IsParenthesized = false);
   void CodeCompleteMemberReferenceExpr(Scope *S, Expr *Base, Expr *OtherOpBase,
   void CodeCompleteMemberReferenceExpr(Scope *S, Expr *Base, Expr *OtherOpBase,
                                        SourceLocation OpLoc, bool IsArrow,
                                        SourceLocation OpLoc, bool IsArrow,
-                                       bool IsBaseExprStatement);
-  void CodeCompletePostfixExpression(Scope *S, ExprResult LHS);
+                                       bool IsBaseExprStatement,
+                                       QualType PreferredType);
+  void CodeCompletePostfixExpression(Scope *S, ExprResult LHS,
+                                     QualType PreferredType);
   void CodeCompleteTag(Scope *S, unsigned TagSpec);
   void CodeCompleteTag(Scope *S, unsigned TagSpec);
   void CodeCompleteTypeQualifiers(DeclSpec &DS);
   void CodeCompleteTypeQualifiers(DeclSpec &DS);
   void CodeCompleteFunctionQualifiers(DeclSpec &DS, Declarator &D,
   void CodeCompleteFunctionQualifiers(DeclSpec &DS, Declarator &D,
@@ -10397,9 +10435,7 @@ public:
                                               IdentifierInfo *II,
                                               IdentifierInfo *II,
                                               SourceLocation OpenParLoc);
                                               SourceLocation OpenParLoc);
   void CodeCompleteInitializer(Scope *S, Decl *D);
   void CodeCompleteInitializer(Scope *S, Decl *D);
-  void CodeCompleteReturn(Scope *S);
   void CodeCompleteAfterIf(Scope *S);
   void CodeCompleteAfterIf(Scope *S);
-  void CodeCompleteBinaryRHS(Scope *S, Expr *LHS, tok::TokenKind Op);
 
 
   void CodeCompleteQualifiedId(Scope *S, CXXScopeSpec &SS,
   void CodeCompleteQualifiedId(Scope *S, CXXScopeSpec &SS,
                                bool EnteringContext, QualType BaseType);
                                bool EnteringContext, QualType BaseType);

+ 2 - 1
lib/Parse/ParseDecl.cpp

@@ -2293,7 +2293,8 @@ Decl *Parser::ParseDeclarationAfterDeclaratorAndAttributes(
         return nullptr;
         return nullptr;
       }
       }
 
 
-      ExprResult Init(ParseInitializer());
+      PreferredType.enterVariableInit(Tok.getLocation(), ThisDecl);
+      ExprResult Init = ParseInitializer();
 
 
       // If this is the only decl in (possibly) range based for statement,
       // If this is the only decl in (possibly) range based for statement,
       // our best guess is that the user meant ':' instead of '='.
       // our best guess is that the user meant ':' instead of '='.

+ 36 - 17
lib/Parse/ParseExpr.cpp

@@ -158,7 +158,8 @@ Parser::ParseExpressionWithLeadingExtension(SourceLocation ExtLoc) {
 /// Parse an expr that doesn't include (top-level) commas.
 /// Parse an expr that doesn't include (top-level) commas.
 ExprResult Parser::ParseAssignmentExpression(TypeCastState isTypeCast) {
 ExprResult Parser::ParseAssignmentExpression(TypeCastState isTypeCast) {
   if (Tok.is(tok::code_completion)) {
   if (Tok.is(tok::code_completion)) {
-    Actions.CodeCompleteOrdinaryName(getCurScope(), Sema::PCC_Expression);
+    Actions.CodeCompleteExpression(getCurScope(),
+                                   PreferredType.get(Tok.getLocation()));
     cutOffParsing();
     cutOffParsing();
     return ExprError();
     return ExprError();
   }
   }
@@ -271,7 +272,10 @@ Parser::ParseRHSOfBinaryExpression(ExprResult LHS, prec::Level MinPrec) {
                                                getLangOpts().CPlusPlus11);
                                                getLangOpts().CPlusPlus11);
   SourceLocation ColonLoc;
   SourceLocation ColonLoc;
 
 
+  auto SavedType = PreferredType;
   while (1) {
   while (1) {
+    // Every iteration may rely on a preferred type for the whole expression.
+    PreferredType = SavedType;
     // If this token has a lower precedence than we are allowed to parse (e.g.
     // If this token has a lower precedence than we are allowed to parse (e.g.
     // because we are called recursively, or because the token is not a binop),
     // because we are called recursively, or because the token is not a binop),
     // then we are done!
     // then we are done!
@@ -392,15 +396,8 @@ Parser::ParseRHSOfBinaryExpression(ExprResult LHS, prec::Level MinPrec) {
       }
       }
     }
     }
 
 
-    // Code completion for the right-hand side of a binary expression goes
-    // through a special hook that takes the left-hand side into account.
-    if (Tok.is(tok::code_completion)) {
-      Actions.CodeCompleteBinaryRHS(getCurScope(), LHS.get(),
-                                    OpToken.getKind());
-      cutOffParsing();
-      return ExprError();
-    }
-
+    PreferredType.enterBinary(Actions, Tok.getLocation(), LHS.get(),
+                              OpToken.getKind());
     // Parse another leaf here for the RHS of the operator.
     // Parse another leaf here for the RHS of the operator.
     // ParseCastExpression works here because all RHS expressions in C have it
     // ParseCastExpression works here because all RHS expressions in C have it
     // as a prefix, at least. However, in C++, an assignment-expression could
     // as a prefix, at least. However, in C++, an assignment-expression could
@@ -763,6 +760,7 @@ ExprResult Parser::ParseCastExpression(bool isUnaryExpression,
                                        bool isVectorLiteral) {
                                        bool isVectorLiteral) {
   ExprResult Res;
   ExprResult Res;
   tok::TokenKind SavedKind = Tok.getKind();
   tok::TokenKind SavedKind = Tok.getKind();
+  auto SavedType = PreferredType;
   NotCastExpr = false;
   NotCastExpr = false;
 
 
   // This handles all of cast-expression, unary-expression, postfix-expression,
   // This handles all of cast-expression, unary-expression, postfix-expression,
@@ -1114,6 +1112,9 @@ ExprResult Parser::ParseCastExpression(bool isUnaryExpression,
     //     -- cast-expression
     //     -- cast-expression
     Token SavedTok = Tok;
     Token SavedTok = Tok;
     ConsumeToken();
     ConsumeToken();
+
+    PreferredType.enterUnary(Actions, Tok.getLocation(), SavedTok.getKind(),
+                             SavedTok.getLocation());
     // One special case is implicitly handled here: if the preceding tokens are
     // One special case is implicitly handled here: if the preceding tokens are
     // an ambiguous cast expression, such as "(T())++", then we recurse to
     // an ambiguous cast expression, such as "(T())++", then we recurse to
     // determine whether the '++' is prefix or postfix.
     // determine whether the '++' is prefix or postfix.
@@ -1135,6 +1136,7 @@ ExprResult Parser::ParseCastExpression(bool isUnaryExpression,
   case tok::amp: {         // unary-expression: '&' cast-expression
   case tok::amp: {         // unary-expression: '&' cast-expression
     // Special treatment because of member pointers
     // Special treatment because of member pointers
     SourceLocation SavedLoc = ConsumeToken();
     SourceLocation SavedLoc = ConsumeToken();
+    PreferredType.enterUnary(Actions, Tok.getLocation(), tok::amp, SavedLoc);
     Res = ParseCastExpression(false, true);
     Res = ParseCastExpression(false, true);
     if (!Res.isInvalid())
     if (!Res.isInvalid())
       Res = Actions.ActOnUnaryOp(getCurScope(), SavedLoc, SavedKind, Res.get());
       Res = Actions.ActOnUnaryOp(getCurScope(), SavedLoc, SavedKind, Res.get());
@@ -1149,6 +1151,7 @@ ExprResult Parser::ParseCastExpression(bool isUnaryExpression,
   case tok::kw___real:     // unary-expression: '__real' cast-expression [GNU]
   case tok::kw___real:     // unary-expression: '__real' cast-expression [GNU]
   case tok::kw___imag: {   // unary-expression: '__imag' cast-expression [GNU]
   case tok::kw___imag: {   // unary-expression: '__imag' cast-expression [GNU]
     SourceLocation SavedLoc = ConsumeToken();
     SourceLocation SavedLoc = ConsumeToken();
+    PreferredType.enterUnary(Actions, Tok.getLocation(), SavedKind, SavedLoc);
     Res = ParseCastExpression(false);
     Res = ParseCastExpression(false);
     if (!Res.isInvalid())
     if (!Res.isInvalid())
       Res = Actions.ActOnUnaryOp(getCurScope(), SavedLoc, SavedKind, Res.get());
       Res = Actions.ActOnUnaryOp(getCurScope(), SavedLoc, SavedKind, Res.get());
@@ -1423,7 +1426,8 @@ ExprResult Parser::ParseCastExpression(bool isUnaryExpression,
     Res = ParseBlockLiteralExpression();
     Res = ParseBlockLiteralExpression();
     break;
     break;
   case tok::code_completion: {
   case tok::code_completion: {
-    Actions.CodeCompleteOrdinaryName(getCurScope(), Sema::PCC_Expression);
+    Actions.CodeCompleteExpression(getCurScope(),
+                                   PreferredType.get(Tok.getLocation()));
     cutOffParsing();
     cutOffParsing();
     return ExprError();
     return ExprError();
   }
   }
@@ -1458,6 +1462,7 @@ ExprResult Parser::ParseCastExpression(bool isUnaryExpression,
   // that the address of the function is being taken, which is illegal in CL.
   // that the address of the function is being taken, which is illegal in CL.
 
 
   // These can be followed by postfix-expr pieces.
   // These can be followed by postfix-expr pieces.
+  PreferredType = SavedType;
   Res = ParsePostfixExpressionSuffix(Res);
   Res = ParsePostfixExpressionSuffix(Res);
   if (getLangOpts().OpenCL)
   if (getLangOpts().OpenCL)
     if (Expr *PostfixExpr = Res.get()) {
     if (Expr *PostfixExpr = Res.get()) {
@@ -1497,13 +1502,17 @@ Parser::ParsePostfixExpressionSuffix(ExprResult LHS) {
   // Now that the primary-expression piece of the postfix-expression has been
   // Now that the primary-expression piece of the postfix-expression has been
   // parsed, see if there are any postfix-expression pieces here.
   // parsed, see if there are any postfix-expression pieces here.
   SourceLocation Loc;
   SourceLocation Loc;
+  auto SavedType = PreferredType;
   while (1) {
   while (1) {
+    // Each iteration relies on preferred type for the whole expression.
+    PreferredType = SavedType;
     switch (Tok.getKind()) {
     switch (Tok.getKind()) {
     case tok::code_completion:
     case tok::code_completion:
       if (InMessageExpression)
       if (InMessageExpression)
         return LHS;
         return LHS;
 
 
-      Actions.CodeCompletePostfixExpression(getCurScope(), LHS);
+      Actions.CodeCompletePostfixExpression(
+          getCurScope(), LHS, PreferredType.get(Tok.getLocation()));
       cutOffParsing();
       cutOffParsing();
       return ExprError();
       return ExprError();
 
 
@@ -1545,6 +1554,7 @@ Parser::ParsePostfixExpressionSuffix(ExprResult LHS) {
       Loc = T.getOpenLocation();
       Loc = T.getOpenLocation();
       ExprResult Idx, Length;
       ExprResult Idx, Length;
       SourceLocation ColonLoc;
       SourceLocation ColonLoc;
+      PreferredType.enterSubscript(Actions, Tok.getLocation(), LHS.get());
       if (getLangOpts().CPlusPlus11 && Tok.is(tok::l_brace)) {
       if (getLangOpts().CPlusPlus11 && Tok.is(tok::l_brace)) {
         Diag(Tok, diag::warn_cxx98_compat_generalized_initializer_lists);
         Diag(Tok, diag::warn_cxx98_compat_generalized_initializer_lists);
         Idx = ParseBraceInitializer();
         Idx = ParseBraceInitializer();
@@ -1726,6 +1736,8 @@ Parser::ParsePostfixExpressionSuffix(ExprResult LHS) {
       bool MayBePseudoDestructor = false;
       bool MayBePseudoDestructor = false;
       Expr* OrigLHS = !LHS.isInvalid() ? LHS.get() : nullptr;
       Expr* OrigLHS = !LHS.isInvalid() ? LHS.get() : nullptr;
 
 
+      PreferredType.enterMemAccess(Actions, Tok.getLocation(), OrigLHS);
+
       if (getLangOpts().CPlusPlus && !LHS.isInvalid()) {
       if (getLangOpts().CPlusPlus && !LHS.isInvalid()) {
         Expr *Base = OrigLHS;
         Expr *Base = OrigLHS;
         const Type* BaseType = Base->getType().getTypePtrOrNull();
         const Type* BaseType = Base->getType().getTypePtrOrNull();
@@ -1772,7 +1784,8 @@ Parser::ParsePostfixExpressionSuffix(ExprResult LHS) {
         // Code completion for a member access expression.
         // Code completion for a member access expression.
         Actions.CodeCompleteMemberReferenceExpr(
         Actions.CodeCompleteMemberReferenceExpr(
             getCurScope(), Base, CorrectedBase, OpLoc, OpKind == tok::arrow,
             getCurScope(), Base, CorrectedBase, OpLoc, OpKind == tok::arrow,
-            Base && ExprStatementTokLoc == Base->getBeginLoc());
+            Base && ExprStatementTokLoc == Base->getBeginLoc(),
+            PreferredType.get(Tok.getLocation()));
 
 
         cutOffParsing();
         cutOffParsing();
         return ExprError();
         return ExprError();
@@ -2326,14 +2339,16 @@ Parser::ParseParenExpression(ParenParseOption &ExprType, bool stopIfCastExpr,
     return ExprError();
     return ExprError();
   SourceLocation OpenLoc = T.getOpenLocation();
   SourceLocation OpenLoc = T.getOpenLocation();
 
 
+  PreferredType.enterParenExpr(Tok.getLocation(), OpenLoc);
+
   ExprResult Result(true);
   ExprResult Result(true);
   bool isAmbiguousTypeId;
   bool isAmbiguousTypeId;
   CastTy = nullptr;
   CastTy = nullptr;
 
 
   if (Tok.is(tok::code_completion)) {
   if (Tok.is(tok::code_completion)) {
-    Actions.CodeCompleteOrdinaryName(getCurScope(),
-                 ExprType >= CompoundLiteral? Sema::PCC_ParenthesizedExpression
-                                            : Sema::PCC_Expression);
+    Actions.CodeCompleteExpression(
+        getCurScope(), PreferredType.get(Tok.getLocation()),
+        /*IsParenthesized=*/ExprType >= CompoundLiteral);
     cutOffParsing();
     cutOffParsing();
     return ExprError();
     return ExprError();
   }
   }
@@ -2414,6 +2429,8 @@ Parser::ParseParenExpression(ParenParseOption &ExprType, bool stopIfCastExpr,
     T.consumeClose();
     T.consumeClose();
     ColonProtection.restore();
     ColonProtection.restore();
     RParenLoc = T.getCloseLocation();
     RParenLoc = T.getCloseLocation();
+
+    PreferredType.enterTypeCast(Tok.getLocation(), Ty.get().get());
     ExprResult SubExpr = ParseCastExpression(/*isUnaryExpression=*/false);
     ExprResult SubExpr = ParseCastExpression(/*isUnaryExpression=*/false);
 
 
     if (Ty.isInvalid() || SubExpr.isInvalid())
     if (Ty.isInvalid() || SubExpr.isInvalid())
@@ -2544,6 +2561,7 @@ Parser::ParseParenExpression(ParenParseOption &ExprType, bool stopIfCastExpr,
           return ExprError();
           return ExprError();
         }
         }
 
 
+        PreferredType.enterTypeCast(Tok.getLocation(), CastTy.get());
         // Parse the cast-expression that follows it next.
         // Parse the cast-expression that follows it next.
         // TODO: For cast expression with CastTy.
         // TODO: For cast expression with CastTy.
         Result = ParseCastExpression(/*isUnaryExpression=*/false,
         Result = ParseCastExpression(/*isUnaryExpression=*/false,
@@ -2845,7 +2863,8 @@ bool Parser::ParseExpressionList(SmallVectorImpl<Expr *> &Exprs,
       if (Completer)
       if (Completer)
         Completer();
         Completer();
       else
       else
-        Actions.CodeCompleteOrdinaryName(getCurScope(), Sema::PCC_Expression);
+        Actions.CodeCompleteExpression(getCurScope(),
+                                       PreferredType.get(Tok.getLocation()));
       cutOffParsing();
       cutOffParsing();
       return true;
       return true;
     }
     }

+ 4 - 0
lib/Parse/ParseExprCXX.cpp

@@ -1672,6 +1672,8 @@ Parser::ParseCXXTypeConstructExpression(const DeclSpec &DS) {
     BalancedDelimiterTracker T(*this, tok::l_paren);
     BalancedDelimiterTracker T(*this, tok::l_paren);
     T.consumeOpen();
     T.consumeOpen();
 
 
+    PreferredType.enterTypeCast(Tok.getLocation(), TypeRep.get());
+
     ExprVector Exprs;
     ExprVector Exprs;
     CommaLocsTy CommaLocs;
     CommaLocsTy CommaLocs;
 
 
@@ -1739,6 +1741,7 @@ Sema::ConditionResult Parser::ParseCXXCondition(StmtResult *InitStmt,
                                                 Sema::ConditionKind CK,
                                                 Sema::ConditionKind CK,
                                                 ForRangeInfo *FRI) {
                                                 ForRangeInfo *FRI) {
   ParenBraceBracketBalancer BalancerRAIIObj(*this);
   ParenBraceBracketBalancer BalancerRAIIObj(*this);
+  PreferredType.enterCondition(Actions, Tok.getLocation());
 
 
   if (Tok.is(tok::code_completion)) {
   if (Tok.is(tok::code_completion)) {
     Actions.CodeCompleteOrdinaryName(getCurScope(), Sema::PCC_Condition);
     Actions.CodeCompleteOrdinaryName(getCurScope(), Sema::PCC_Condition);
@@ -1858,6 +1861,7 @@ Sema::ConditionResult Parser::ParseCXXCondition(StmtResult *InitStmt,
          diag::warn_cxx98_compat_generalized_initializer_lists);
          diag::warn_cxx98_compat_generalized_initializer_lists);
     InitExpr = ParseBraceInitializer();
     InitExpr = ParseBraceInitializer();
   } else if (CopyInitialization) {
   } else if (CopyInitialization) {
+    PreferredType.enterVariableInit(Tok.getLocation(), DeclOut);
     InitExpr = ParseAssignmentExpression();
     InitExpr = ParseAssignmentExpression();
   } else if (Tok.is(tok::l_paren)) {
   } else if (Tok.is(tok::l_paren)) {
     // This was probably an attempt to initialize the variable.
     // This was probably an attempt to initialize the variable.

+ 4 - 1
lib/Parse/ParseStmt.cpp

@@ -1970,9 +1970,12 @@ StmtResult Parser::ParseReturnStatement() {
 
 
   ExprResult R;
   ExprResult R;
   if (Tok.isNot(tok::semi)) {
   if (Tok.isNot(tok::semi)) {
+    if (!IsCoreturn)
+      PreferredType.enterReturn(Actions, Tok.getLocation());
     // FIXME: Code completion for co_return.
     // FIXME: Code completion for co_return.
     if (Tok.is(tok::code_completion) && !IsCoreturn) {
     if (Tok.is(tok::code_completion) && !IsCoreturn) {
-      Actions.CodeCompleteReturn(getCurScope());
+      Actions.CodeCompleteExpression(getCurScope(),
+                                     PreferredType.get(Tok.getLocation()));
       cutOffParsing();
       cutOffParsing();
       return StmtError();
       return StmtError();
     }
     }

+ 199 - 112
lib/Sema/SemaCodeComplete.cpp

@@ -347,6 +347,182 @@ public:
 };
 };
 } // namespace
 } // namespace
 
 
+void PreferredTypeBuilder::enterReturn(Sema &S, SourceLocation Tok) {
+  if (isa<BlockDecl>(S.CurContext)) {
+    if (sema::BlockScopeInfo *BSI = S.getCurBlock()) {
+      Type = BSI->ReturnType;
+      ExpectedLoc = Tok;
+    }
+  } else if (const auto *Function = dyn_cast<FunctionDecl>(S.CurContext)) {
+    Type = Function->getReturnType();
+    ExpectedLoc = Tok;
+  } else if (const auto *Method = dyn_cast<ObjCMethodDecl>(S.CurContext)) {
+    Type = Method->getReturnType();
+    ExpectedLoc = Tok;
+  }
+}
+
+void PreferredTypeBuilder::enterVariableInit(SourceLocation Tok, Decl *D) {
+  auto *VD = llvm::dyn_cast_or_null<ValueDecl>(D);
+  Type = VD ? VD->getType() : QualType();
+  ExpectedLoc = Tok;
+}
+
+void PreferredTypeBuilder::enterParenExpr(SourceLocation Tok,
+                                          SourceLocation LParLoc) {
+  // expected type for parenthesized expression does not change.
+  if (ExpectedLoc == LParLoc)
+    ExpectedLoc = Tok;
+}
+
+static QualType getPreferredTypeOfBinaryRHS(Sema &S, Expr *LHS,
+                                            tok::TokenKind Op) {
+  if (!LHS)
+    return QualType();
+
+  QualType LHSType = LHS->getType();
+  if (LHSType->isPointerType()) {
+    if (Op == tok::plus || Op == tok::plusequal || Op == tok::minusequal)
+      return S.getASTContext().getPointerDiffType();
+    // Pointer difference is more common than subtracting an int from a pointer.
+    if (Op == tok::minus)
+      return LHSType;
+  }
+
+  switch (Op) {
+  // No way to infer the type of RHS from LHS.
+  case tok::comma:
+    return QualType();
+  // Prefer the type of the left operand for all of these.
+  // Arithmetic operations.
+  case tok::plus:
+  case tok::plusequal:
+  case tok::minus:
+  case tok::minusequal:
+  case tok::percent:
+  case tok::percentequal:
+  case tok::slash:
+  case tok::slashequal:
+  case tok::star:
+  case tok::starequal:
+  // Assignment.
+  case tok::equal:
+  // Comparison operators.
+  case tok::equalequal:
+  case tok::exclaimequal:
+  case tok::less:
+  case tok::lessequal:
+  case tok::greater:
+  case tok::greaterequal:
+  case tok::spaceship:
+    return LHS->getType();
+  // Binary shifts are often overloaded, so don't try to guess those.
+  case tok::greatergreater:
+  case tok::greatergreaterequal:
+  case tok::lessless:
+  case tok::lesslessequal:
+    if (LHSType->isIntegralOrEnumerationType())
+      return S.getASTContext().IntTy;
+    return QualType();
+  // Logical operators, assume we want bool.
+  case tok::ampamp:
+  case tok::pipepipe:
+  case tok::caretcaret:
+    return S.getASTContext().BoolTy;
+  // Operators often used for bit manipulation are typically used with the type
+  // of the left argument.
+  case tok::pipe:
+  case tok::pipeequal:
+  case tok::caret:
+  case tok::caretequal:
+  case tok::amp:
+  case tok::ampequal:
+    if (LHSType->isIntegralOrEnumerationType())
+      return LHSType;
+    return QualType();
+  // RHS should be a pointer to a member of the 'LHS' type, but we can't give
+  // any particular type here.
+  case tok::periodstar:
+  case tok::arrowstar:
+    return QualType();
+  default:
+    // FIXME(ibiryukov): handle the missing op, re-add the assertion.
+    // assert(false && "unhandled binary op");
+    return QualType();
+  }
+}
+
+/// Get preferred type for an argument of an unary expression. \p ContextType is
+/// preferred type of the whole unary expression.
+static QualType getPreferredTypeOfUnaryArg(Sema &S, QualType ContextType,
+                                           tok::TokenKind Op) {
+  switch (Op) {
+  case tok::exclaim:
+    return S.getASTContext().BoolTy;
+  case tok::amp:
+    if (!ContextType.isNull() && ContextType->isPointerType())
+      return ContextType->getPointeeType();
+    return QualType();
+  case tok::star:
+    if (ContextType.isNull())
+      return QualType();
+    return S.getASTContext().getPointerType(ContextType.getNonReferenceType());
+  case tok::plus:
+  case tok::minus:
+  case tok::tilde:
+  case tok::minusminus:
+  case tok::plusplus:
+    if (ContextType.isNull())
+      return S.getASTContext().IntTy;
+    // leave as is, these operators typically return the same type.
+    return ContextType;
+  case tok::kw___real:
+  case tok::kw___imag:
+    return QualType();
+  default:
+    assert(false && "unhnalded unary op");
+    return QualType();
+  }
+}
+
+void PreferredTypeBuilder::enterBinary(Sema &S, SourceLocation Tok, Expr *LHS,
+                                       tok::TokenKind Op) {
+  Type = getPreferredTypeOfBinaryRHS(S, LHS, Op);
+  ExpectedLoc = Tok;
+}
+
+void PreferredTypeBuilder::enterMemAccess(Sema &S, SourceLocation Tok,
+                                          Expr *Base) {
+  if (!Base)
+    return;
+  Type = this->get(Base->getBeginLoc());
+  ExpectedLoc = Tok;
+}
+
+void PreferredTypeBuilder::enterUnary(Sema &S, SourceLocation Tok,
+                                      tok::TokenKind OpKind,
+                                      SourceLocation OpLoc) {
+  Type = getPreferredTypeOfUnaryArg(S, this->get(OpLoc), OpKind);
+  ExpectedLoc = Tok;
+}
+
+void PreferredTypeBuilder::enterSubscript(Sema &S, SourceLocation Tok,
+                                          Expr *LHS) {
+  Type = S.getASTContext().IntTy;
+  ExpectedLoc = Tok;
+}
+
+void PreferredTypeBuilder::enterTypeCast(SourceLocation Tok,
+                                         QualType CastType) {
+  Type = !CastType.isNull() ? CastType.getCanonicalType() : QualType();
+  ExpectedLoc = Tok;
+}
+
+void PreferredTypeBuilder::enterCondition(Sema &S, SourceLocation Tok) {
+  Type = S.getASTContext().BoolTy;
+  ExpectedLoc = Tok;
+}
+
 class ResultBuilder::ShadowMapEntry::iterator {
 class ResultBuilder::ShadowMapEntry::iterator {
   llvm::PointerUnion<const NamedDecl *, const DeclIndexPair *> DeclOrIterator;
   llvm::PointerUnion<const NamedDecl *, const DeclIndexPair *> DeclOrIterator;
   unsigned SingleDeclIndex;
   unsigned SingleDeclIndex;
@@ -3856,13 +4032,15 @@ void Sema::CodeCompleteDeclSpec(Scope *S, DeclSpec &DS,
 }
 }
 
 
 struct Sema::CodeCompleteExpressionData {
 struct Sema::CodeCompleteExpressionData {
-  CodeCompleteExpressionData(QualType PreferredType = QualType())
+  CodeCompleteExpressionData(QualType PreferredType = QualType(),
+                             bool IsParenthesized = false)
       : PreferredType(PreferredType), IntegralConstantExpression(false),
       : PreferredType(PreferredType), IntegralConstantExpression(false),
-        ObjCCollection(false) {}
+        ObjCCollection(false), IsParenthesized(IsParenthesized) {}
 
 
   QualType PreferredType;
   QualType PreferredType;
   bool IntegralConstantExpression;
   bool IntegralConstantExpression;
   bool ObjCCollection;
   bool ObjCCollection;
+  bool IsParenthesized;
   SmallVector<Decl *, 4> IgnoreDecls;
   SmallVector<Decl *, 4> IgnoreDecls;
 };
 };
 
 
@@ -3873,13 +4051,18 @@ void Sema::CodeCompleteExpression(Scope *S,
   ResultBuilder Results(
   ResultBuilder Results(
       *this, CodeCompleter->getAllocator(),
       *this, CodeCompleter->getAllocator(),
       CodeCompleter->getCodeCompletionTUInfo(),
       CodeCompleter->getCodeCompletionTUInfo(),
-      CodeCompletionContext(CodeCompletionContext::CCC_Expression,
-                            Data.PreferredType));
+      CodeCompletionContext(
+          Data.IsParenthesized
+              ? CodeCompletionContext::CCC_ParenthesizedExpression
+              : CodeCompletionContext::CCC_Expression,
+          Data.PreferredType));
+  auto PCC =
+      Data.IsParenthesized ? PCC_ParenthesizedExpression : PCC_Expression;
   if (Data.ObjCCollection)
   if (Data.ObjCCollection)
     Results.setFilter(&ResultBuilder::IsObjCCollection);
     Results.setFilter(&ResultBuilder::IsObjCCollection);
   else if (Data.IntegralConstantExpression)
   else if (Data.IntegralConstantExpression)
     Results.setFilter(&ResultBuilder::IsIntegralConstantValue);
     Results.setFilter(&ResultBuilder::IsIntegralConstantValue);
-  else if (WantTypesInContext(PCC_Expression, getLangOpts()))
+  else if (WantTypesInContext(PCC, getLangOpts()))
     Results.setFilter(&ResultBuilder::IsOrdinaryName);
     Results.setFilter(&ResultBuilder::IsOrdinaryName);
   else
   else
     Results.setFilter(&ResultBuilder::IsOrdinaryNonTypeName);
     Results.setFilter(&ResultBuilder::IsOrdinaryNonTypeName);
@@ -3897,7 +4080,7 @@ void Sema::CodeCompleteExpression(Scope *S,
                      CodeCompleter->loadExternal());
                      CodeCompleter->loadExternal());
 
 
   Results.EnterNewScope();
   Results.EnterNewScope();
-  AddOrdinaryNameResults(PCC_Expression, S, *this, Results);
+  AddOrdinaryNameResults(PCC, S, *this, Results);
   Results.ExitScope();
   Results.ExitScope();
 
 
   bool PreferredTypeIsPointer = false;
   bool PreferredTypeIsPointer = false;
@@ -3917,13 +4100,16 @@ void Sema::CodeCompleteExpression(Scope *S,
                             Results.data(), Results.size());
                             Results.data(), Results.size());
 }
 }
 
 
-void Sema::CodeCompleteExpression(Scope *S, QualType PreferredType) {
-  return CodeCompleteExpression(S, CodeCompleteExpressionData(PreferredType));
+void Sema::CodeCompleteExpression(Scope *S, QualType PreferredType,
+                                  bool IsParenthesized) {
+  return CodeCompleteExpression(
+      S, CodeCompleteExpressionData(PreferredType, IsParenthesized));
 }
 }
 
 
-void Sema::CodeCompletePostfixExpression(Scope *S, ExprResult E) {
+void Sema::CodeCompletePostfixExpression(Scope *S, ExprResult E,
+                                         QualType PreferredType) {
   if (E.isInvalid())
   if (E.isInvalid())
-    CodeCompleteOrdinaryName(S, PCC_RecoveryInFunction);
+    CodeCompleteExpression(S, PreferredType);
   else if (getLangOpts().ObjC)
   else if (getLangOpts().ObjC)
     CodeCompleteObjCInstanceMessage(S, E.get(), None, false);
     CodeCompleteObjCInstanceMessage(S, E.get(), None, false);
 }
 }
@@ -4211,7 +4397,8 @@ AddRecordMembersCompletionResults(Sema &SemaRef, ResultBuilder &Results,
 void Sema::CodeCompleteMemberReferenceExpr(Scope *S, Expr *Base,
 void Sema::CodeCompleteMemberReferenceExpr(Scope *S, Expr *Base,
                                            Expr *OtherOpBase,
                                            Expr *OtherOpBase,
                                            SourceLocation OpLoc, bool IsArrow,
                                            SourceLocation OpLoc, bool IsArrow,
-                                           bool IsBaseExprStatement) {
+                                           bool IsBaseExprStatement,
+                                           QualType PreferredType) {
   if (!Base || !CodeCompleter)
   if (!Base || !CodeCompleter)
     return;
     return;
 
 
@@ -4239,6 +4426,7 @@ void Sema::CodeCompleteMemberReferenceExpr(Scope *S, Expr *Base,
   }
   }
 
 
   CodeCompletionContext CCContext(contextKind, ConvertedBaseType);
   CodeCompletionContext CCContext(contextKind, ConvertedBaseType);
+  CCContext.setPreferredType(PreferredType);
   ResultBuilder Results(*this, CodeCompleter->getAllocator(),
   ResultBuilder Results(*this, CodeCompleter->getAllocator(),
                         CodeCompleter->getCodeCompletionTUInfo(), CCContext,
                         CodeCompleter->getCodeCompletionTUInfo(), CCContext,
                         &ResultBuilder::IsMember);
                         &ResultBuilder::IsMember);
@@ -4800,22 +4988,6 @@ void Sema::CodeCompleteInitializer(Scope *S, Decl *D) {
   CodeCompleteExpression(S, Data);
   CodeCompleteExpression(S, Data);
 }
 }
 
 
-void Sema::CodeCompleteReturn(Scope *S) {
-  QualType ResultType;
-  if (isa<BlockDecl>(CurContext)) {
-    if (BlockScopeInfo *BSI = getCurBlock())
-      ResultType = BSI->ReturnType;
-  } else if (const auto *Function = dyn_cast<FunctionDecl>(CurContext))
-    ResultType = Function->getReturnType();
-  else if (const auto *Method = dyn_cast<ObjCMethodDecl>(CurContext))
-    ResultType = Method->getReturnType();
-
-  if (ResultType.isNull())
-    CodeCompleteOrdinaryName(S, PCC_Expression);
-  else
-    CodeCompleteExpression(S, ResultType);
-}
-
 void Sema::CodeCompleteAfterIf(Scope *S) {
 void Sema::CodeCompleteAfterIf(Scope *S) {
   ResultBuilder Results(*this, CodeCompleter->getAllocator(),
   ResultBuilder Results(*this, CodeCompleter->getAllocator(),
                         CodeCompleter->getCodeCompletionTUInfo(),
                         CodeCompleter->getCodeCompletionTUInfo(),
@@ -4877,91 +5049,6 @@ void Sema::CodeCompleteAfterIf(Scope *S) {
                             Results.data(), Results.size());
                             Results.data(), Results.size());
 }
 }
 
 
-static QualType getPreferredTypeOfBinaryRHS(Sema &S, Expr *LHS,
-                                            tok::TokenKind Op) {
-  if (!LHS)
-    return QualType();
-
-  QualType LHSType = LHS->getType();
-  if (LHSType->isPointerType()) {
-    if (Op == tok::plus || Op == tok::plusequal || Op == tok::minusequal)
-      return S.getASTContext().getPointerDiffType();
-    // Pointer difference is more common than subtracting an int from a pointer.
-    if (Op == tok::minus)
-      return LHSType;
-  }
-
-  switch (Op) {
-  // No way to infer the type of RHS from LHS.
-  case tok::comma:
-    return QualType();
-  // Prefer the type of the left operand for all of these.
-  // Arithmetic operations.
-  case tok::plus:
-  case tok::plusequal:
-  case tok::minus:
-  case tok::minusequal:
-  case tok::percent:
-  case tok::percentequal:
-  case tok::slash:
-  case tok::slashequal:
-  case tok::star:
-  case tok::starequal:
-  // Assignment.
-  case tok::equal:
-  // Comparison operators.
-  case tok::equalequal:
-  case tok::exclaimequal:
-  case tok::less:
-  case tok::lessequal:
-  case tok::greater:
-  case tok::greaterequal:
-  case tok::spaceship:
-    return LHS->getType();
-  // Binary shifts are often overloaded, so don't try to guess those.
-  case tok::greatergreater:
-  case tok::greatergreaterequal:
-  case tok::lessless:
-  case tok::lesslessequal:
-    if (LHSType->isIntegralOrEnumerationType())
-      return S.getASTContext().IntTy;
-    return QualType();
-  // Logical operators, assume we want bool.
-  case tok::ampamp:
-  case tok::pipepipe:
-  case tok::caretcaret:
-    return S.getASTContext().BoolTy;
-  // Operators often used for bit manipulation are typically used with the type
-  // of the left argument.
-  case tok::pipe:
-  case tok::pipeequal:
-  case tok::caret:
-  case tok::caretequal:
-  case tok::amp:
-  case tok::ampequal:
-    if (LHSType->isIntegralOrEnumerationType())
-      return LHSType;
-    return QualType();
-  // RHS should be a pointer to a member of the 'LHS' type, but we can't give
-  // any particular type here.
-  case tok::periodstar:
-  case tok::arrowstar:
-    return QualType();
-  default:
-    // FIXME(ibiryukov): handle the missing op, re-add the assertion.
-    // assert(false && "unhandled binary op");
-    return QualType();
-  }
-}
-
-void Sema::CodeCompleteBinaryRHS(Scope *S, Expr *LHS, tok::TokenKind Op) {
-  auto PreferredType = getPreferredTypeOfBinaryRHS(*this, LHS, Op);
-  if (!PreferredType.isNull())
-    CodeCompleteExpression(S, PreferredType);
-  else
-    CodeCompleteOrdinaryName(S, PCC_Expression);
-}
-
 void Sema::CodeCompleteQualifiedId(Scope *S, CXXScopeSpec &SS,
 void Sema::CodeCompleteQualifiedId(Scope *S, CXXScopeSpec &SS,
                                    bool EnteringContext, QualType BaseType) {
                                    bool EnteringContext, QualType BaseType) {
   if (SS.isEmpty() || !CodeCompleter)
   if (SS.isEmpty() || !CodeCompleter)

+ 99 - 0
unittests/Sema/CodeCompleteTest.cpp

@@ -339,4 +339,103 @@ TEST(PreferredTypeTest, BinaryExpr) {
   EXPECT_THAT(collectPreferredTypes(Code), Each("NULL TYPE"));
   EXPECT_THAT(collectPreferredTypes(Code), Each("NULL TYPE"));
 }
 }
 
 
+TEST(PreferredTypeTest, Members) {
+  StringRef Code = R"cpp(
+    struct vector {
+      int *begin();
+      vector clone();
+    };
+
+    void test(int *a) {
+      a = ^vector().^clone().^begin();
+    }
+  )cpp";
+  EXPECT_THAT(collectPreferredTypes(Code), Each("int *"));
+}
+
+TEST(PreferredTypeTest, Conditions) {
+  StringRef Code = R"cpp(
+    struct vector {
+      bool empty();
+    };
+
+    void test() {
+      if (^vector().^empty()) {}
+      while (^vector().^empty()) {}
+      for (; ^vector().^empty();) {}
+    }
+  )cpp";
+  EXPECT_THAT(collectPreferredTypes(Code), Each("_Bool"));
+}
+
+TEST(PreferredTypeTest, InitAndAssignment) {
+  StringRef Code = R"cpp(
+    struct vector {
+      int* begin();
+    };
+
+    void test() {
+      const int* x = ^vector().^begin();
+      x = ^vector().^begin();
+
+      if (const int* y = ^vector().^begin()) {}
+    }
+  )cpp";
+  EXPECT_THAT(collectPreferredTypes(Code), Each("const int *"));
+}
+
+TEST(PreferredTypeTest, UnaryExprs) {
+  StringRef Code = R"cpp(
+    void test(long long a) {
+      a = +^a;
+      a = -^a
+      a = ++^a;
+      a = --^a;
+    }
+  )cpp";
+  EXPECT_THAT(collectPreferredTypes(Code), Each("long long"));
+
+  Code = R"cpp(
+    void test(int a, int *ptr) {
+      !^a;
+      !^ptr;
+      !!!^a;
+
+      a = !^a;
+      a = !^ptr;
+      a = !!!^a;
+    }
+  )cpp";
+  EXPECT_THAT(collectPreferredTypes(Code), Each("_Bool"));
+
+  Code = R"cpp(
+    void test(int a) {
+      const int* x = &^a;
+    }
+  )cpp";
+  EXPECT_THAT(collectPreferredTypes(Code), Each("const int"));
+
+  Code = R"cpp(
+    void test(int *a) {
+      int x = *^a;
+      int &r = *^a;
+    }
+  )cpp";
+  EXPECT_THAT(collectPreferredTypes(Code), Each("int *"));
+
+  Code = R"cpp(
+    void test(int a) {
+      *^a;
+      &^a;
+    }
+
+  )cpp";
+}
+
+TEST(PreferredTypeTest, ParenExpr) {
+  StringRef Code = R"cpp(
+    const int *i = ^(^(^(^10)));
+  )cpp";
+  EXPECT_THAT(collectPreferredTypes(Code), Each("const int *"));
+}
 } // namespace
 } // namespace