diff --git a/internal/lsp/completion.go b/internal/lsp/completion.go index 583822d8..ca963b7e 100644 --- a/internal/lsp/completion.go +++ b/internal/lsp/completion.go @@ -5,22 +5,30 @@ package lsp import ( + "fmt" "sort" + "strings" "golang.org/x/tools/internal/lsp/protocol" "golang.org/x/tools/internal/lsp/source" ) -func toProtocolCompletionItems(items []source.CompletionItem) []protocol.CompletionItem { +func toProtocolCompletionItems(items []source.CompletionItem, snippetsSupported, signatureHelpEnabled bool) []protocol.CompletionItem { var results []protocol.CompletionItem sort.Slice(items, func(i, j int) bool { return items[i].Score > items[j].Score }) + insertTextFormat := protocol.PlainTextFormat + if snippetsSupported { + insertTextFormat = protocol.SnippetTextFormat + } for _, item := range items { results = append(results, protocol.CompletionItem{ - Label: item.Label, - Detail: item.Detail, - Kind: float64(toProtocolCompletionItemKind(item.Kind)), + Label: item.Label, + InsertText: labelToProtocolSnippets(item.Label, item.Kind, insertTextFormat, signatureHelpEnabled), + Detail: item.Detail, + Kind: float64(toProtocolCompletionItemKind(item.Kind)), + InsertTextFormat: insertTextFormat, }) } return results @@ -49,5 +57,47 @@ func toProtocolCompletionItemKind(kind source.CompletionItemKind) protocol.Compl default: return protocol.TextCompletion } - +} + +func labelToProtocolSnippets(label string, kind source.CompletionItemKind, insertTextFormat protocol.InsertTextFormat, signatureHelpEnabled bool) string { + switch kind { + case source.ConstantCompletionItem: + // The label for constants is of the format " = ". + // We should now insert the " = " part of the label. + return label[:strings.Index(label, " =")] + case source.FunctionCompletionItem, source.MethodCompletionItem: + trimmed := label[:strings.Index(label, "(")] + params := strings.Trim(label[strings.Index(label, "("):], "()") + if params == "" { + return label + } + // Don't add parameters or parens for the plaintext insert format. + if insertTextFormat == protocol.PlainTextFormat { + return trimmed + } + // If we do have signature help enabled, the user can see parameters as + // they type in the function, so we just return empty parentheses. + if signatureHelpEnabled { + return trimmed + "($1)" + } + // If signature help is not enabled, we should give the user parameters + // that they can tab through. The insert text format follows the + // specification defined by Microsoft for LSP. The "$", "}, and "\" + // characters should be escaped. + r := strings.NewReplacer( + `\`, `\\`, + `}`, `\}`, + `$`, `\$`, + ) + trimmed += "(" + for i, p := range strings.Split(params, ",") { + if i != 0 { + trimmed += ", " + } + trimmed += fmt.Sprintf("${%v:%v}", i+1, r.Replace(strings.Trim(p, " "))) + } + return trimmed + ")" + + } + return label } diff --git a/internal/lsp/lsp_test.go b/internal/lsp/lsp_test.go index 95762c39..a6dec705 100644 --- a/internal/lsp/lsp_test.go +++ b/internal/lsp/lsp_test.go @@ -31,6 +31,9 @@ func TestLSP(t *testing.T) { func testLSP(t *testing.T, exporter packagestest.Exporter) { const dir = "testdata" + + // We hardcode the expected number of test cases to ensure that all tests + // are being executed. If a test is added, this number must be changed. const expectedCompletionsCount = 43 const expectedDiagnosticsCount = 14 const expectedFormatCount = 3 @@ -52,23 +55,16 @@ func testLSP(t *testing.T, exporter packagestest.Exporter) { exported := packagestest.Export(t, exporter, modules) defer exported.Cleanup() - // collect results for certain tests - expectedDiagnostics := make(diagnostics) - completionItems := make(completionItems) - expectedCompletions := make(completions) - expectedFormat := make(formats) - expectedDefinitions := make(definitions) - s := &server{ view: source.NewView(), } - // merge the config objects + // Merge the exported.Config with the view.Config. cfg := *exported.Config cfg.Fset = s.view.Config.Fset cfg.Mode = packages.LoadSyntax s.view.Config = &cfg - // Do a first pass to collect special markers + // Do a first pass to collect special markers for completion. if err := exported.Expect(map[string]interface{}{ "item": func(name string, r packagestest.Range, _, _ string) { exported.Mark(name, r) @@ -76,6 +72,13 @@ func testLSP(t *testing.T, exporter packagestest.Exporter) { }); err != nil { t.Fatal(err) } + + expectedDiagnostics := make(diagnostics) + completionItems := make(completionItems) + expectedCompletions := make(completions) + expectedFormat := make(formats) + expectedDefinitions := make(definitions) + // Collect any data that needs to be used by subsequent tests. if err := exported.Expect(map[string]interface{}{ "diag": expectedDiagnostics.collect, @@ -134,66 +137,6 @@ type completions map[token.Position][]token.Pos type formats map[string]string type definitions map[protocol.Location]protocol.Location -func (c completions) test(t *testing.T, exported *packagestest.Exported, s *server, items completionItems) { - for src, itemList := range c { - var want []protocol.CompletionItem - for _, pos := range itemList { - want = append(want, *items[pos]) - } - list, err := s.Completion(context.Background(), &protocol.CompletionParams{ - TextDocumentPositionParams: protocol.TextDocumentPositionParams{ - TextDocument: protocol.TextDocumentIdentifier{ - URI: protocol.DocumentURI(source.ToURI(src.Filename)), - }, - Position: protocol.Position{ - Line: float64(src.Line - 1), - Character: float64(src.Column - 1), - }, - }, - }) - if err != nil { - t.Fatalf("completion failed for %s:%v:%v: %v", filepath.Base(src.Filename), src.Line, src.Column, err) - } - got := list.Items - if equal := reflect.DeepEqual(want, got); !equal { - t.Errorf(diffC(src, want, got)) - } - } -} - -func (c completions) collect(src token.Position, expected []token.Pos) { - c[src] = expected -} - -func (i completionItems) collect(pos token.Pos, label, detail, kind string) { - var k protocol.CompletionItemKind - switch kind { - case "struct": - k = protocol.StructCompletion - case "func": - k = protocol.FunctionCompletion - case "var": - k = protocol.VariableCompletion - case "type": - k = protocol.TypeParameterCompletion - case "field": - k = protocol.FieldCompletion - case "interface": - k = protocol.InterfaceCompletion - case "const": - k = protocol.ConstantCompletion - case "method": - k = protocol.MethodCompletion - case "package": - k = protocol.ModuleCompletion - } - i[pos] = &protocol.CompletionItem{ - Label: label, - Detail: detail, - Kind: float64(k), - } -} - func (d diagnostics) test(t *testing.T, exported *packagestest.Exported, v *source.View) int { count := 0 for filename, want := range d { @@ -241,6 +184,66 @@ func (d diagnostics) collect(pos token.Position, msg string) { d[pos.Filename] = append(d[pos.Filename], want) } +func (c completions) test(t *testing.T, exported *packagestest.Exported, s *server, items completionItems) { + for src, itemList := range c { + var want []protocol.CompletionItem + for _, pos := range itemList { + want = append(want, *items[pos]) + } + list, err := s.Completion(context.Background(), &protocol.CompletionParams{ + TextDocumentPositionParams: protocol.TextDocumentPositionParams{ + TextDocument: protocol.TextDocumentIdentifier{ + URI: protocol.DocumentURI(source.ToURI(src.Filename)), + }, + Position: protocol.Position{ + Line: float64(src.Line - 1), + Character: float64(src.Column - 1), + }, + }, + }) + if err != nil { + t.Fatalf("completion failed for %s:%v:%v: %v", filepath.Base(src.Filename), src.Line, src.Column, err) + } + got := list.Items + if diff := diffC(src, want, got); diff != "" { + t.Errorf(diff) + } + } +} + +func (c completions) collect(src token.Position, expected []token.Pos) { + c[src] = expected +} + +func (i completionItems) collect(pos token.Pos, label, detail, kind string) { + var k protocol.CompletionItemKind + switch kind { + case "struct": + k = protocol.StructCompletion + case "func": + k = protocol.FunctionCompletion + case "var": + k = protocol.VariableCompletion + case "type": + k = protocol.TypeParameterCompletion + case "field": + k = protocol.FieldCompletion + case "interface": + k = protocol.InterfaceCompletion + case "const": + k = protocol.ConstantCompletion + case "method": + k = protocol.MethodCompletion + case "package": + k = protocol.ModuleCompletion + } + i[pos] = &protocol.CompletionItem{ + Label: label, + Detail: detail, + Kind: float64(k), + } +} + func (f formats) test(t *testing.T, s *server) { for filename, gofmted := range f { edits, err := s.Formatting(context.Background(), &protocol.DocumentFormattingParams{ @@ -313,6 +316,23 @@ func diffD(filename string, want, got []protocol.Diagnostic) string { // diffC prints the diff between expected and actual completion test results. func diffC(pos token.Position, want, got []protocol.CompletionItem) string { + if len(got) != len(want) { + goto Failed + } + for i, w := range want { + g := got[i] + if w.Label != g.Label { + goto Failed + } + if w.Detail != g.Detail { + goto Failed + } + if w.Kind != g.Kind { + goto Failed + } + } + return "" +Failed: msg := &bytes.Buffer{} fmt.Fprintf(msg, "completion failed for %s:%v:%v:\nexpected:\n", filepath.Base(pos.Filename), pos.Line, pos.Column) for _, d := range want { diff --git a/internal/lsp/server.go b/internal/lsp/server.go index 5c9b4f2b..a62183ba 100644 --- a/internal/lsp/server.go +++ b/internal/lsp/server.go @@ -30,6 +30,9 @@ type server struct { initializedMu sync.Mutex initialized bool // set once the server has received "initialize" request + signatureHelpEnabled bool + snippetsSupported bool + view *source.View } @@ -40,7 +43,12 @@ func (s *server) Initialize(ctx context.Context, params *protocol.InitializePara return nil, jsonrpc2.NewErrorf(jsonrpc2.CodeInvalidRequest, "server already initialized") } s.view = source.NewView() - s.initialized = true + s.initialized = true // mark server as initialized now + + // Check if the client supports snippets in completion items. + s.snippetsSupported = params.Capabilities.TextDocument.Completion.CompletionItem.SnippetSupport + s.signatureHelpEnabled = true + return &protocol.InitializeResult{ Capabilities: protocol.ServerCapabilities{ CompletionProvider: protocol.CompletionOptions{ @@ -167,7 +175,7 @@ func (s *server) Completion(ctx context.Context, params *protocol.CompletionPara } return &protocol.CompletionList{ IsIncomplete: false, - Items: toProtocolCompletionItems(items), + Items: toProtocolCompletionItems(items, s.snippetsSupported, s.signatureHelpEnabled), }, nil }