diff --git a/internal/lsp/lsp_test.go b/internal/lsp/lsp_test.go index 840966de..422c096b 100644 --- a/internal/lsp/lsp_test.go +++ b/internal/lsp/lsp_test.go @@ -28,6 +28,9 @@ func TestLSP(t *testing.T) { func testLSP(t *testing.T, exporter packagestest.Exporter) { const dir = "testdata" + const expectedCompletionsCount = 4 + const expectedDiagnosticsCount = 7 + const expectedFormatCount = 3 files := packagestest.MustCopyFileTree(dir) for fragment, operation := range files { @@ -48,9 +51,10 @@ func testLSP(t *testing.T, exporter packagestest.Exporter) { dirs := make(map[string]bool) // collect results for certain tests - expectedDiagnostics := make(map[string][]protocol.Diagnostic) - expectedCompletions := make(map[token.Position]*protocol.CompletionItem) - expectedFormat := make(map[string]string) + expectedDiagnostics := make(diagnostics) + completionItems := make(completionItems) + expectedCompletions := make(completions) + expectedFormat := make(formats) s := &server{ view: source.NewView(), @@ -81,70 +85,76 @@ func testLSP(t *testing.T, exporter packagestest.Exporter) { } // Collect any data that needs to be used by subsequent tests. if err := exported.Expect(map[string]interface{}{ - "diag": func(pos token.Position, msg string) { - collectDiagnostics(t, expectedDiagnostics, pos, msg) - }, - "item": func(pos token.Position, label, detail, kind string) { - collectCompletionItems(expectedCompletions, pos, label, detail, kind) - }, - "format": func(pos token.Position) { - collectFormat(expectedFormat, pos) - }, + "diag": expectedDiagnostics.collect, + "item": completionItems.collect, + "complete": expectedCompletions.collect, + "format": expectedFormat.collect, }); err != nil { t.Fatal(err) } - // test completion - testCompletion(t, exported, s, expectedCompletions) + t.Run("Completion", func(t *testing.T) { + t.Helper() + if len(expectedCompletions) != expectedCompletionsCount { + t.Errorf("got %v completions expected %v", len(expectedCompletions), expectedCompletionsCount) + } + expectedCompletions.test(t, exported, s, completionItems) + }) - // test diagnostics - var dirList []string - for dir := range dirs { - dirList = append(dirList, dir) - } - exported.Config.Mode = packages.LoadFiles - pkgs, err := packages.Load(exported.Config, dirList...) - if err != nil { - t.Fatal(err) - } - testDiagnostics(t, s.view, pkgs, expectedDiagnostics) + t.Run("Diagnostics", func(t *testing.T) { + t.Helper() + diagnosticsCount := expectedDiagnostics.test(t, exported, s.view, dirs) + if diagnosticsCount != expectedDiagnosticsCount { + t.Errorf("got %v diagnostics expected %v", diagnosticsCount, expectedDiagnosticsCount) + } + }) - // test format - testFormat(t, s, expectedFormat) + t.Run("Format", func(t *testing.T) { + t.Helper() + if len(expectedFormat) != expectedFormatCount { + t.Errorf("got %v formats expected %v", len(expectedFormat), expectedFormatCount) + } + expectedFormat.test(t, s) + }) } -func testCompletion(t *testing.T, exported *packagestest.Exported, s *server, wants map[token.Position]*protocol.CompletionItem) { - if err := exported.Expect(map[string]interface{}{ - "complete": func(src token.Position, expected []token.Position) { - var want []protocol.CompletionItem - for _, pos := range expected { - want = append(want, *wants[pos]) - } - list, err := s.Completion(context.Background(), &protocol.CompletionParams{ - TextDocumentPositionParams: protocol.TextDocumentPositionParams{ - TextDocument: protocol.TextDocumentIdentifier{ - URI: protocol.DocumentURI(source.ToURI(src.Filename)), - }, - Position: protocol.Position{ - Line: float64(src.Line - 1), - Character: float64(src.Column - 1), - }, +type diagnostics map[string][]protocol.Diagnostic +type completionItems map[token.Pos]*protocol.CompletionItem +type completions map[token.Position][]token.Pos +type formats map[string]string + +func (c completions) test(t *testing.T, exported *packagestest.Exported, s *server, items completionItems) { + for src, itemList := range c { + var want []protocol.CompletionItem + for _, pos := range itemList { + want = append(want, *items[pos]) + } + list, err := s.Completion(context.Background(), &protocol.CompletionParams{ + TextDocumentPositionParams: protocol.TextDocumentPositionParams{ + TextDocument: protocol.TextDocumentIdentifier{ + URI: protocol.DocumentURI(source.ToURI(src.Filename)), }, - }) - if err != nil { - t.Fatal(err) - } - got := list.Items - if equal := reflect.DeepEqual(want, got); !equal { - t.Errorf("completion failed for %s:%v:%v: (expected: %v), (got: %v)", filepath.Base(src.Filename), src.Line, src.Column, want, got) - } - }, - }); err != nil { - t.Fatal(err) + Position: protocol.Position{ + Line: float64(src.Line - 1), + Character: float64(src.Column - 1), + }, + }, + }) + if err != nil { + t.Fatal(err) + } + got := list.Items + if equal := reflect.DeepEqual(want, got); !equal { + t.Errorf("completion failed for %s:%v:%v: (expected: %v), (got: %v)", filepath.Base(src.Filename), src.Line, src.Column, want, got) + } } } -func collectCompletionItems(expectedCompletions map[token.Position]*protocol.CompletionItem, pos token.Position, label, detail, kind string) { +func (c completions) collect(src token.Position, expected []token.Pos) { + c[src] = expected +} + +func (i completionItems) collect(pos token.Pos, label, detail, kind string) { var k protocol.CompletionItemKind switch kind { case "struct": @@ -164,14 +174,26 @@ func collectCompletionItems(expectedCompletions map[token.Position]*protocol.Com case "method": k = protocol.MethodCompletion } - expectedCompletions[pos] = &protocol.CompletionItem{ + i[pos] = &protocol.CompletionItem{ Label: label, Detail: detail, Kind: float64(k), } } -func testDiagnostics(t *testing.T, v *source.View, pkgs []*packages.Package, wants map[string][]protocol.Diagnostic) { +func (d diagnostics) test(t *testing.T, exported *packagestest.Exported, v *source.View, dirs map[string]bool) int { + // first trigger a load to get the diagnostics + var dirList []string + for dir := range dirs { + dirList = append(dirList, dir) + } + exported.Config.Mode = packages.LoadFiles + pkgs, err := packages.Load(exported.Config, dirList...) + if err != nil { + t.Fatal(err) + } + // and now see if they match the expected ones + count := 0 for _, pkg := range pkgs { for _, filename := range pkg.GoFiles { f := v.GetFile(source.ToURI(filename)) @@ -183,7 +205,7 @@ func testDiagnostics(t *testing.T, v *source.View, pkgs []*packages.Package, wan sort.Slice(got, func(i int, j int) bool { return got[i].Range.Start.Line < got[j].Range.Start.Line }) - want := wants[filename] + want := d[filename] if equal := reflect.DeepEqual(want, got); !equal { msg := &bytes.Buffer{} fmt.Fprintf(msg, "diagnostics failed for %s: expected:\n", filepath.Base(filename)) @@ -196,11 +218,13 @@ func testDiagnostics(t *testing.T, v *source.View, pkgs []*packages.Package, wan } t.Error(msg.String()) } + count += len(want) } } + return count } -func collectDiagnostics(t *testing.T, expectedDiagnostics map[string][]protocol.Diagnostic, pos token.Position, msg string) { +func (d diagnostics) collect(pos token.Position, msg string) { line := float64(pos.Line - 1) col := float64(pos.Column - 1) want := protocol.Diagnostic{ @@ -218,15 +242,11 @@ func collectDiagnostics(t *testing.T, expectedDiagnostics map[string][]protocol. Source: "LSP", Message: msg, } - if _, ok := expectedDiagnostics[pos.Filename]; ok { - expectedDiagnostics[pos.Filename] = append(expectedDiagnostics[pos.Filename], want) - } else { - t.Errorf("unexpected filename: %v", pos.Filename) - } + d[pos.Filename] = append(d[pos.Filename], want) } -func testFormat(t *testing.T, s *server, expectedFormat map[string]string) { - for filename, gofmted := range expectedFormat { +func (f formats) test(t *testing.T, s *server) { + for filename, gofmted := range f { edits, err := s.Formatting(context.Background(), &protocol.DocumentFormattingParams{ TextDocument: protocol.TextDocumentIdentifier{ URI: protocol.DocumentURI(source.ToURI(filename)), @@ -245,10 +265,10 @@ func testFormat(t *testing.T, s *server, expectedFormat map[string]string) { } } -func collectFormat(expectedFormat map[string]string, pos token.Position) { +func (f formats) collect(pos token.Position) { cmd := exec.Command("gofmt", pos.Filename) stdout := bytes.NewBuffer(nil) cmd.Stdout = stdout cmd.Run() // ignore error, sometimes we have intentionally ungofmt-able files - expectedFormat[pos.Filename] = stdout.String() + f[pos.Filename] = stdout.String() }