diff --git a/internal/lsp/source/completion.go b/internal/lsp/source/completion.go index d654397e..8b193b74 100644 --- a/internal/lsp/source/completion.go +++ b/internal/lsp/source/completion.go @@ -619,6 +619,15 @@ func (c *completer) expectedCompositeLiteralType() types.Type { return nil } +// typeModifier represents an operator that changes the expected type. +type typeModifier int + +const ( + dereference typeModifier = iota // dereference ("*") operator + reference // reference ("&") operator + chanRead // channel read ("<-") operator +) + // expectedType returns the expected type for an expression at the query position. func expectedType(c *completer) types.Type { if c.enclosingCompositeLiteral != nil { @@ -626,19 +635,18 @@ func expectedType(c *completer) types.Type { } var ( - derefCount int // count of deref "*" operators - refCount int // count of reference "&" operators - typ types.Type + modifiers []typeModifier + typ types.Type ) Nodes: - for _, node := range c.path { - switch expr := node.(type) { + for i, node := range c.path { + switch node := node.(type) { case *ast.BinaryExpr: // Determine if query position comes from left or right of op. - e := expr.X - if c.pos < expr.OpPos { - e = expr.Y + e := node.X + if c.pos < node.OpPos { + e = node.Y } if tv, ok := c.info.Types[e]; ok { typ = tv.Type @@ -646,12 +654,12 @@ Nodes: } case *ast.AssignStmt: // Only rank completions if you are on the right side of the token. - if c.pos > expr.TokPos { - i := indexExprAtPos(c.pos, expr.Rhs) - if i >= len(expr.Lhs) { - i = len(expr.Lhs) - 1 + if c.pos > node.TokPos { + i := indexExprAtPos(c.pos, node.Rhs) + if i >= len(node.Lhs) { + i = len(node.Lhs) - 1 } - if tv, ok := c.info.Types[expr.Lhs[i]]; ok { + if tv, ok := c.info.Types[node.Lhs[i]]; ok { typ = tv.Type break Nodes } @@ -659,13 +667,13 @@ Nodes: return nil case *ast.CallExpr: // Only consider CallExpr args if position falls between parens. - if expr.Lparen <= c.pos && c.pos <= expr.Rparen { - if tv, ok := c.info.Types[expr.Fun]; ok { + if node.Lparen <= c.pos && c.pos <= node.Rparen { + if tv, ok := c.info.Types[node.Fun]; ok { if sig, ok := tv.Type.(*types.Signature); ok { if sig.Params().Len() == 0 { return nil } - i := indexExprAtPos(c.pos, expr.Args) + i := indexExprAtPos(c.pos, node.Args) // Make sure not to run past the end of expected parameters. if i >= sig.Params().Len() { i = sig.Params().Len() - 1 @@ -678,21 +686,65 @@ Nodes: return nil case *ast.ReturnStmt: if sig := c.enclosingFunction; sig != nil { - // Find signature result that corresponds to our return expression. - if resultIdx := indexExprAtPos(c.pos, expr.Results); resultIdx < len(expr.Results) { + // Find signature result that corresponds to our return statement. + if resultIdx := indexExprAtPos(c.pos, node.Results); resultIdx < len(node.Results) { if resultIdx < sig.Results().Len() { typ = sig.Results().At(resultIdx).Type() break Nodes } } } - + return nil + case *ast.CaseClause: + if swtch, ok := findSwitchStmt(c.path[i+1:], c.pos, node).(*ast.SwitchStmt); ok { + if tv, ok := c.info.Types[swtch.Tag]; ok { + typ = tv.Type + break Nodes + } + } + return nil + case *ast.SliceExpr: + // Make sure position falls within the brackets (e.g. "foo[a:<>]"). + if node.Lbrack < c.pos && c.pos <= node.Rbrack { + typ = types.Typ[types.Int] + break Nodes + } + return nil + case *ast.IndexExpr: + // Make sure position falls within the brackets (e.g. "foo[<>]"). + if node.Lbrack < c.pos && c.pos <= node.Rbrack { + if tv, ok := c.info.Types[node.X]; ok { + switch t := tv.Type.Underlying().(type) { + case *types.Map: + typ = t.Key() + case *types.Slice, *types.Array: + typ = types.Typ[types.Int] + default: + return nil + } + break Nodes + } + } + return nil + case *ast.SendStmt: + // Make sure we are on right side of arrow (e.g. "foo <- <>"). + if c.pos > node.Arrow+1 { + if tv, ok := c.info.Types[node.Chan]; ok { + if ch, ok := tv.Type.Underlying().(*types.Chan); ok { + typ = ch.Elem() + break Nodes + } + } + } return nil case *ast.StarExpr: - derefCount++ + modifiers = append(modifiers, dereference) case *ast.UnaryExpr: - if expr.Op == token.AND { - refCount++ + switch node.Op { + case token.AND: + modifiers = append(modifiers, reference) + case token.ARROW: + modifiers = append(modifiers, chanRead) } default: if breaksExpectedTypeInference(node) { @@ -702,16 +754,17 @@ Nodes: } if typ != nil { - // For every "*" deref operator, add another pointer layer to expected type. - for i := 0; i < derefCount; i++ { - typ = types.NewPointer(typ) - } - // For every "&" ref operator, remove a pointer layer from expected type. - for i := 0; i < refCount; i++ { - if ptr, ok := typ.(*types.Pointer); ok { - typ = ptr.Elem() - } else { - break + for _, mod := range modifiers { + switch mod { + case dereference: + // For every "*" deref operator, add another pointer layer to expected type. + typ = types.NewPointer(typ) + case reference: + // For every "&" ref operator, remove a pointer layer from expected type. + typ = deref(typ) + case chanRead: + // For every "<-" operator, add another layer of channelness. + typ = types.NewChan(types.SendRecv, typ) } } } @@ -719,6 +772,30 @@ Nodes: return typ } +// findSwitchStmt returns an *ast.CaseClause's corresponding *ast.SwitchStmt or +// *ast.TypeSwitchStmt. path should start from the case clause's first ancestor. +func findSwitchStmt(path []ast.Node, pos token.Pos, c *ast.CaseClause) ast.Stmt { + // Make sure position falls within a "case <>:" clause. + if exprAtPos(pos, c.List) == nil { + return nil + } + // A case clause is always nested within a block statement in a switch statement. + if len(path) < 2 { + return nil + } + if _, ok := path[0].(*ast.BlockStmt); !ok { + return nil + } + switch s := path[1].(type) { + case *ast.SwitchStmt: + return s + case *ast.TypeSwitchStmt: + return s + default: + return nil + } +} + // breaksExpectedTypeInference reports if an expression node's type is unrelated // to its child expression node types. For example, "Foo{Bar: x.Baz(<>)}" should // expect a function argument, not a composite literal value. @@ -737,7 +814,7 @@ func breaksExpectedTypeInference(n ast.Node) bool { // func (<>) foo(<>) (<>) {} // func preferTypeNames(path []ast.Node, pos token.Pos) bool { - for _, p := range path { + for i, p := range path { switch n := p.(type) { case *ast.FuncDecl: if r := n.Recv; r != nil && r.Pos() <= pos && pos <= r.End() { @@ -752,6 +829,13 @@ func preferTypeNames(path []ast.Node, pos token.Pos) bool { } } return false + case *ast.CaseClause: + _, isTypeSwitch := findSwitchStmt(path[i+1:], pos, n).(*ast.TypeSwitchStmt) + return isTypeSwitch + case *ast.TypeAssertExpr: + if n.Lparen < pos && pos <= n.Rparen { + return true + } } } return false diff --git a/internal/lsp/testdata/channel/channel.go b/internal/lsp/testdata/channel/channel.go new file mode 100644 index 00000000..a83b8953 --- /dev/null +++ b/internal/lsp/testdata/channel/channel.go @@ -0,0 +1,25 @@ +package channel + +func _() { + var ( + aa = "123" //@item(channelAA, "aa", "string", "var") + ab = 123 //@item(channelAB, "ab", "int", "var") + ) + + { + type myChan chan int + var mc myChan + mc <- a //@complete(" //", channelAB, channelAA) + } + + { + var ac chan int //@item(channelAC, "ac", "chan int", "var") + a <- a //@complete(" <-", channelAC, channelAA, channelAB) + } + + { + var foo chan int //@item(channelFoo, "foo", "chan int", "var") + wantsInt := func(int) {} //@item(channelWantsInt, "wantsInt", "func(int)", "var") + wantsInt(<-) //@complete(")", channelFoo, channelWantsInt, channelAA, channelAB) + } +} diff --git a/internal/lsp/testdata/func_rank/func_rank.go.in b/internal/lsp/testdata/func_rank/func_rank.go.in index d950d3e0..ca983241 100644 --- a/internal/lsp/testdata/func_rank/func_rank.go.in +++ b/internal/lsp/testdata/func_rank/func_rank.go.in @@ -26,8 +26,8 @@ func _() { // no expected type fnInt(func() int { s.A }) //@complete(" }", rankAA, rankAB, rankAC) fnInt(s.A()) //@complete("()", rankAA, rankAB, rankAC) - fnInt([]int{}[s.A]) //@complete("])", rankAA, rankAB, rankAC) - fnInt([]int{}[:s.A]) //@complete("])", rankAA, rankAB, rankAC) + fnInt([]int{}[s.A]) //@complete("])", rankAA, rankAC, rankAB) + fnInt([]int{}[:s.A]) //@complete("])", rankAA, rankAC, rankAB) fnInt(s.A.(int)) //@complete(".(", rankAA, rankAB, rankAC) diff --git a/internal/lsp/testdata/index/index.go b/internal/lsp/testdata/index/index.go new file mode 100644 index 00000000..7e56b511 --- /dev/null +++ b/internal/lsp/testdata/index/index.go @@ -0,0 +1,21 @@ +package index + +func _() { + var ( + aa = "123" //@item(indexAA, "aa", "string", "var") + ab = 123 //@item(indexAB, "ab", "int", "var") + ) + + var foo [1]int + foo[a] //@complete("]", indexAB, indexAA) + foo[:a] //@complete("]", indexAB, indexAA) + a[:a] //@complete("[", indexAA, indexAB) + a[a] //@complete("[", indexAA, indexAB) + + var bar map[string]int + bar[a] //@complete("]", indexAA, indexAB) + + type myMap map[string]int + var baz myMap + baz[a] //@complete("]", indexAA, indexAB) +} diff --git a/internal/lsp/testdata/rank/switch_rank.go.in b/internal/lsp/testdata/rank/switch_rank.go.in new file mode 100644 index 00000000..9e23f6b5 --- /dev/null +++ b/internal/lsp/testdata/rank/switch_rank.go.in @@ -0,0 +1,12 @@ +package rank + +func _() { + switch pear { + case : //@complete(":", pear, apple) + } + + switch pear { + case "hi": + //@complete("", apple, pear) + } +} diff --git a/internal/lsp/testdata/rank/type_assert_rank.go.in b/internal/lsp/testdata/rank/type_assert_rank.go.in new file mode 100644 index 00000000..3490c85b --- /dev/null +++ b/internal/lsp/testdata/rank/type_assert_rank.go.in @@ -0,0 +1,8 @@ +package rank + +func _() { + type flower int //@item(flower, "flower", "int", "type") + var fig string //@item(fig, "fig", "string", "var") + + _ = interface{}(nil).(f) //@complete(") //", flower, fig) +} diff --git a/internal/lsp/testdata/rank/type_switch_rank.go.in b/internal/lsp/testdata/rank/type_switch_rank.go.in new file mode 100644 index 00000000..457c64b8 --- /dev/null +++ b/internal/lsp/testdata/rank/type_switch_rank.go.in @@ -0,0 +1,11 @@ +package rank + +func _() { + type basket int //@item(basket, "basket", "int", "type") + var banana string //@item(banana, "banana", "string", "var") + + switch interface{}(pear).(type) { + case b: //@complete(":", basket, banana) + b //@complete(" //", banana, basket) + } +} diff --git a/internal/lsp/tests/tests.go b/internal/lsp/tests/tests.go index 7b064788..d374d70d 100644 --- a/internal/lsp/tests/tests.go +++ b/internal/lsp/tests/tests.go @@ -28,7 +28,7 @@ import ( // We hardcode the expected number of test cases to ensure that all tests // are being executed. If a test is added, this number must be changed. const ( - ExpectedCompletionsCount = 107 + ExpectedCompletionsCount = 121 ExpectedCompletionSnippetCount = 13 ExpectedDiagnosticsCount = 17 ExpectedFormatCount = 5