diff --git a/go/packages/packagestest/expect.go b/go/packages/packagestest/expect.go index 3a5d88e8..cf4b6738 100644 --- a/go/packages/packagestest/expect.go +++ b/go/packages/packagestest/expect.go @@ -106,11 +106,11 @@ func (e *Exported) Expect(methods map[string]interface{}) error { for i, convert := range mi.converters { params[i], args, err = convert(n, args) if err != nil { - return fmt.Errorf("%v: %v", e.fset.Position(n.Pos), err) + return fmt.Errorf("%v: %v", e.ExpectFileSet.Position(n.Pos), err) } } if len(args) > 0 { - return fmt.Errorf("%v: unwanted args got %+v extra", e.fset.Position(n.Pos), args) + return fmt.Errorf("%v: unwanted args got %+v extra", e.ExpectFileSet.Position(n.Pos), args) } //TODO: catch the error returned from the method mi.f.Call(params) @@ -154,7 +154,7 @@ func (e *Exported) getNotes() error { if err != nil { return err } - l, err := expect.Parse(e.fset, filename, content) + l, err := expect.Parse(e.ExpectFileSet, filename, content) if err != nil { return fmt.Errorf("Failed to extract expectations: %v", err) } @@ -214,7 +214,7 @@ func (e *Exported) buildConverter(pt reflect.Type) (converter, error) { }, nil case pt == fsetType: return func(n *expect.Note, args []interface{}) (reflect.Value, []interface{}, error) { - return reflect.ValueOf(e.fset), args, nil + return reflect.ValueOf(e.ExpectFileSet), args, nil }, nil case pt == exportedType: return func(n *expect.Note, args []interface{}) (reflect.Value, []interface{}, error) { @@ -234,7 +234,7 @@ func (e *Exported) buildConverter(pt reflect.Type) (converter, error) { if err != nil { return reflect.Value{}, nil, err } - return reflect.ValueOf(e.fset.Position(r.Start)), remains, nil + return reflect.ValueOf(e.ExpectFileSet.Position(r.Start)), remains, nil }, nil case pt == rangeType: return func(n *expect.Note, args []interface{}) (reflect.Value, []interface{}, error) { @@ -369,9 +369,9 @@ func (e *Exported) rangeConverter(n *expect.Note, args []interface{}) (span.Rang switch arg { case eofIdentifier: // end of file identifier, look up the current file - f := e.fset.File(n.Pos) + f := e.ExpectFileSet.File(n.Pos) eof := f.Pos(f.Size()) - return span.Range{FileSet: e.fset, Start: eof, End: token.NoPos}, args, nil + return span.Range{FileSet: e.ExpectFileSet, Start: eof, End: token.NoPos}, args, nil default: // look up an marker by name mark, ok := e.markers[string(arg)] @@ -381,23 +381,23 @@ func (e *Exported) rangeConverter(n *expect.Note, args []interface{}) (span.Rang return mark, args, nil } case string: - start, end, err := expect.MatchBefore(e.fset, e.FileContents, n.Pos, arg) + start, end, err := expect.MatchBefore(e.ExpectFileSet, e.FileContents, n.Pos, arg) if err != nil { return span.Range{}, nil, err } if start == token.NoPos { - return span.Range{}, nil, fmt.Errorf("%v: pattern %s did not match", e.fset.Position(n.Pos), arg) + return span.Range{}, nil, fmt.Errorf("%v: pattern %s did not match", e.ExpectFileSet.Position(n.Pos), arg) } - return span.Range{FileSet: e.fset, Start: start, End: end}, args, nil + return span.Range{FileSet: e.ExpectFileSet, Start: start, End: end}, args, nil case *regexp.Regexp: - start, end, err := expect.MatchBefore(e.fset, e.FileContents, n.Pos, arg) + start, end, err := expect.MatchBefore(e.ExpectFileSet, e.FileContents, n.Pos, arg) if err != nil { return span.Range{}, nil, err } if start == token.NoPos { - return span.Range{}, nil, fmt.Errorf("%v: pattern %s did not match", e.fset.Position(n.Pos), arg) + return span.Range{}, nil, fmt.Errorf("%v: pattern %s did not match", e.ExpectFileSet.Position(n.Pos), arg) } - return span.Range{FileSet: e.fset, Start: start, End: end}, args, nil + return span.Range{FileSet: e.ExpectFileSet, Start: start, End: end}, args, nil default: return span.Range{}, nil, fmt.Errorf("cannot convert %v to pos", arg) } diff --git a/go/packages/packagestest/export.go b/go/packages/packagestest/export.go index e927313e..cb9e6be1 100644 --- a/go/packages/packagestest/export.go +++ b/go/packages/packagestest/export.go @@ -62,10 +62,11 @@ type Exported struct { // Modules is the module description that was used to produce this exported data set. Modules []Module + ExpectFileSet *token.FileSet // The file set used when parsing expectations + temp string // the temporary directory that was exported to primary string // the first non GOROOT module that was exported written map[string]map[string]string // the full set of exported files - fset *token.FileSet // The file set used when parsing expectations notes []*expect.Note // The list of expectations extracted from go source files markers map[string]span.Range // The set of markers extracted from go source files } @@ -140,11 +141,11 @@ func Export(t testing.TB, exporter Exporter, modules []Module) *Exported { Tests: true, Mode: packages.LoadImports, }, - Modules: modules, - temp: temp, - primary: modules[0].Name, - written: map[string]map[string]string{}, - fset: token.NewFileSet(), + Modules: modules, + temp: temp, + primary: modules[0].Name, + written: map[string]map[string]string{}, + ExpectFileSet: token.NewFileSet(), } defer func() { if t.Failed() || t.Skipped() { diff --git a/internal/lsp/cmd/check_test.go b/internal/lsp/cmd/check_test.go index 0665b4b7..4efd7d1e 100644 --- a/internal/lsp/cmd/check_test.go +++ b/internal/lsp/cmd/check_test.go @@ -7,49 +7,38 @@ package cmd_test import ( "context" "fmt" + "runtime" "strings" "testing" - "golang.org/x/tools/go/packages/packagestest" - "golang.org/x/tools/internal/lsp/cmd" - "golang.org/x/tools/internal/lsp/source" + "golang.org/x/tools/internal/lsp/tests" "golang.org/x/tools/internal/span" "golang.org/x/tools/internal/tool" ) -type diagnostics map[string][]source.Diagnostic - -func (l diagnostics) collect(spn span.Span, msgSource, msg string) { - fname, err := spn.URI().Filename() - if err != nil { - return +func (r *runner) Diagnostics(t *testing.T, data tests.Diagnostics) { + if runtime.GOOS != "linux" || isRace { + t.Skip("currently uses too much memory, see issue #31611") } - //TODO: diagnostics with range - spn = span.New(spn.URI(), spn.Start(), span.Point{}) - l[fname] = append(l[fname], source.Diagnostic{ - Span: spn, - Message: msg, - Source: msgSource, - Severity: source.SeverityError, - }) -} - -func (l diagnostics) test(t *testing.T, e *packagestest.Exported) { - count := 0 - for fname, want := range l { + for uri, want := range data { if len(want) == 1 && want[0].Message == "" { continue } + fname, err := uri.Filename() + if err != nil { + t.Fatal(err) + } args := []string{"-remote=internal"} args = append(args, "check", fname) - app := &cmd.Application{} - app.Config = *e.Config out := captureStdOut(t, func() { - tool.Main(context.Background(), app, args) + tool.Main(context.Background(), r.app, args) }) // parse got into a collection of reports got := map[string]struct{}{} for _, l := range strings.Split(out, "\n") { + if len(l) == 0 { + continue + } // parse and reprint to normalize the span bits := strings.SplitN(l, ": ", 2) if len(bits) == 2 { @@ -60,7 +49,8 @@ func (l diagnostics) test(t *testing.T, e *packagestest.Exported) { got[l] = struct{}{} } for _, diag := range want { - expect := fmt.Sprintf("%v: %v", diag.Span, diag.Message) + spn := span.New(diag.Span.URI(), diag.Span.Start(), diag.Span.Start()) + expect := fmt.Sprintf("%v: %v", spn, diag.Message) _, found := got[expect] if !found { t.Errorf("missing diagnostic %q", expect) @@ -71,9 +61,5 @@ func (l diagnostics) test(t *testing.T, e *packagestest.Exported) { for extra, _ := range got { t.Errorf("extra diagnostic %q", extra) } - count += len(want) - } - if count != expectedDiagnosticsCount { - t.Errorf("got %v diagnostics expected %v", count, expectedDiagnosticsCount) } } diff --git a/internal/lsp/cmd/cmd_race_test.go b/internal/lsp/cmd/cmd_race_test.go new file mode 100644 index 00000000..dca3df42 --- /dev/null +++ b/internal/lsp/cmd/cmd_race_test.go @@ -0,0 +1,11 @@ +// Copyright 2019 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build race + +package cmd_test + +func init() { + isRace = true +} diff --git a/internal/lsp/cmd/cmd_test.go b/internal/lsp/cmd/cmd_test.go index 4151af55..5389664c 100644 --- a/internal/lsp/cmd/cmd_test.go +++ b/internal/lsp/cmd/cmd_test.go @@ -7,166 +7,55 @@ package cmd_test import ( "io/ioutil" "os" - "path/filepath" "strings" "testing" "golang.org/x/tools/go/packages/packagestest" - "golang.org/x/tools/internal/lsp/source" - "golang.org/x/tools/internal/span" + "golang.org/x/tools/internal/lsp/cmd" + "golang.org/x/tools/internal/lsp/tests" ) -// We hardcode the expected number of test cases to ensure that all tests -// are being executed. If a test is added, this number must be changed. -const ( - expectedCompletionsCount = 65 - expectedDiagnosticsCount = 16 - expectedFormatCount = 4 -) +var isRace = false + +type runner struct { + data *tests.Data + app *cmd.Application +} func TestCommandLine(t *testing.T) { packagestest.TestAll(t, testCommandLine) } func testCommandLine(t *testing.T, exporter packagestest.Exporter) { - const dir = "../testdata" + data := tests.Load(t, exporter, "../testdata") + defer data.Exported.Cleanup() - files := packagestest.MustCopyFileTree(dir) - overlays := map[string][]byte{} - for fragment, operation := range files { - if trimmed := strings.TrimSuffix(fragment, ".in"); trimmed != fragment { - delete(files, fragment) - files[trimmed] = operation - } - const overlay = ".overlay" - if index := strings.Index(fragment, overlay); index >= 0 { - delete(files, fragment) - partial := fragment[:index] + fragment[index+len(overlay):] - contents, err := ioutil.ReadFile(filepath.Join(dir, fragment)) - if err != nil { - t.Fatal(err) - } - overlays[partial] = contents - } - } - modules := []packagestest.Module{ - { - Name: "golang.org/x/tools/internal/lsp", - Files: files, - Overlay: overlays, + r := &runner{ + data: data, + app: &cmd.Application{ + Config: *data.Exported.Config, }, } - exported := packagestest.Export(t, exporter, modules) - defer exported.Cleanup() - - // Do a first pass to collect special markers for completion. - if err := exported.Expect(map[string]interface{}{ - "item": func(name string, r packagestest.Range, _, _ string) { - exported.Mark(name, r) - }, - }); err != nil { - t.Fatal(err) - } - - expectedDiagnostics := make(diagnostics) - completionItems := make(completionItems) - expectedCompletions := make(completions) - expectedFormat := make(formats) - expectedDefinitions := make(definitions) - expectedTypeDefinitions := make(definitions) - - // Collect any data that needs to be used by subsequent tests. - if err := exported.Expect(map[string]interface{}{ - "diag": expectedDiagnostics.collect, - "item": completionItems.collect, - "complete": expectedCompletions.collect, - "format": expectedFormat.collect, - "godef": expectedDefinitions.godef, - "definition": expectedDefinitions.definition, - "typdef": expectedTypeDefinitions.typdef, - }); err != nil { - t.Fatal(err) - } - - t.Run("Completion", func(t *testing.T) { - t.Helper() - expectedCompletions.test(t, exported, completionItems) - }) - - t.Run("Diagnostics", func(t *testing.T) { - t.Helper() - expectedDiagnostics.test(t, exported) - }) - - t.Run("Format", func(t *testing.T) { - t.Helper() - expectedFormat.test(t, exported) - }) - - t.Run("Definitions", func(t *testing.T) { - t.Helper() - expectedDefinitions.testDefinitions(t, exported) - }) - - t.Run("TypeDefinitions", func(t *testing.T) { - t.Helper() - expectedTypeDefinitions.testTypeDefinitions(t, exported) - }) + tests.Run(t, r, data) } -type completionItems map[span.Range]*source.CompletionItem -type completions map[span.Span][]span.Span -type formats map[span.URI]span.Span - -func (l completionItems) collect(spn span.Range, label, detail, kind string) { - var k source.CompletionItemKind - switch kind { - case "struct": - k = source.StructCompletionItem - case "func": - k = source.FunctionCompletionItem - case "var": - k = source.VariableCompletionItem - case "type": - k = source.TypeCompletionItem - case "field": - k = source.FieldCompletionItem - case "interface": - k = source.InterfaceCompletionItem - case "const": - k = source.ConstantCompletionItem - case "method": - k = source.MethodCompletionItem - case "package": - k = source.PackageCompletionItem - } - l[spn] = &source.CompletionItem{ - Label: label, - Detail: detail, - Kind: k, - } -} - -func (l completions) collect(src span.Span, expected []span.Span) { - l[src] = expected -} - -func (l completions) test(t *testing.T, e *packagestest.Exported, items completionItems) { - if len(l) != expectedCompletionsCount { - t.Errorf("got %v completions expected %v", len(l), expectedCompletionsCount) - } +func (r *runner) Completion(t *testing.T, data tests.Completions, items tests.CompletionItems) { //TODO: add command line completions tests when it works } -func (l formats) collect(src span.Span) { - l[src.URI()] = src +func (r *runner) Format(t *testing.T, data tests.Formats) { + //TODO: add command line formatting tests when it works } -func (l formats) test(t *testing.T, e *packagestest.Exported) { - if len(l) != expectedFormatCount { - t.Errorf("got %v formats expected %v", len(l), expectedFormatCount) - } - //TODO: add command line formatting tests when it works +func (r *runner) Highlight(t *testing.T, data tests.Highlights) { + //TODO: add command line highlight tests when it works +} +func (r *runner) Symbol(t *testing.T, data tests.Symbols) { + //TODO: add command line symbol tests when it works +} + +func (r *runner) Signature(t *testing.T, data tests.Signatures) { + //TODO: add command line signature tests when it works } func captureStdOut(t testing.TB, f func()) string { diff --git a/internal/lsp/cmd/definition_test.go b/internal/lsp/cmd/definition_test.go index d1388a96..0419b16d 100644 --- a/internal/lsp/cmd/definition_test.go +++ b/internal/lsp/cmd/definition_test.go @@ -17,8 +17,8 @@ import ( "strings" "testing" - "golang.org/x/tools/go/packages/packagestest" "golang.org/x/tools/internal/lsp/cmd" + "golang.org/x/tools/internal/lsp/tests" "golang.org/x/tools/internal/span" "golang.org/x/tools/internal/tool" ) @@ -28,21 +28,15 @@ const ( expectedTypeDefinitionsCount = 2 ) -type definition struct { - src span.Span - flags string - def span.Span - pattern pattern -} - -type definitions map[span.Span]definition - var verifyGuru = flag.Bool("verify-guru", false, "Check that the guru compatability matches") func TestDefinitionHelpExample(t *testing.T) { if runtime.GOOS == "android" { t.Skip("not all source files are available on android") } + if runtime.GOOS != "linux" || isRace { + t.Skip("currently uses too much memory, see issue #31611") + } dir, err := os.Getwd() if err != nil { t.Errorf("could not get wd: %v", err) @@ -64,53 +58,28 @@ func TestDefinitionHelpExample(t *testing.T) { } } -func (l definitions) godef(src, def span.Span) { - l[src] = definition{ - src: src, - def: def, - pattern: newPattern("", def), - } -} - -func (l definitions) typdef(src, def span.Span) { - l[src] = definition{ - src: src, - def: def, - pattern: newPattern("", def), - } -} - -func (l definitions) definition(src span.Span, flags string, def span.Span, match string) { - l[src] = definition{ - src: src, - flags: flags, - def: def, - pattern: newPattern(match, def), - } -} - -func (l definitions) testDefinitions(t *testing.T, e *packagestest.Exported) { - if len(l) != expectedDefinitionsCount { - t.Errorf("got %v definitions expected %v", len(l), expectedDefinitionsCount) - } - for _, d := range l { +func (r *runner) Definition(t *testing.T, data tests.Definitions) { + for _, d := range data { + if d.IsType { + // TODO: support type definition queries + continue + } args := []string{"query"} - if d.flags != "" { - args = append(args, strings.Split(d.flags, " ")...) + if d.Flags != "" { + args = append(args, strings.Split(d.Flags, " ")...) } args = append(args, "definition") - src := span.New(d.src.URI(), span.NewPoint(0, 0, d.src.Start().Offset()), span.Point{}) + src := span.New(d.Src.URI(), span.NewPoint(0, 0, d.Src.Start().Offset()), span.Point{}) args = append(args, fmt.Sprint(src)) - app := &cmd.Application{} - app.Config = *e.Config got := captureStdOut(t, func() { - tool.Main(context.Background(), app, args) + tool.Main(context.Background(), r.app, args) }) - if !d.pattern.matches(got) { - t.Errorf("definition %v\nexpected:\n%s\ngot:\n%s", args, d.pattern, got) + pattern := newPattern(d.Match, d.Def) + if !pattern.matches(got) { + t.Errorf("definition %v\nexpected:\n%s\ngot:\n%s", args, pattern, got) } if *verifyGuru { - moduleMode := e.File(e.Modules[0].Name, "go.mod") != "" + moduleMode := r.data.Exported.File(r.data.Exported.Modules[0].Name, "go.mod") != "" var guruArgs []string runGuru := false if !moduleMode { @@ -133,14 +102,14 @@ func (l definitions) testDefinitions(t *testing.T, e *packagestest.Exported) { } if runGuru { cmd := exec.Command("guru", guruArgs...) - cmd.Env = e.Config.Env + cmd.Env = r.data.Exported.Config.Env out, err := cmd.CombinedOutput() if err != nil { t.Errorf("Could not run guru %v: %v\n%s", guruArgs, err, out) } else { guru := strings.TrimSpace(string(out)) - if !d.pattern.matches(guru) { - t.Errorf("definition %v\nexpected:\n%s\nguru gave:\n%s", args, d.pattern, guru) + if !pattern.matches(guru) { + t.Errorf("definition %v\nexpected:\n%s\nguru gave:\n%s", args, pattern, guru) } } } @@ -148,13 +117,6 @@ func (l definitions) testDefinitions(t *testing.T, e *packagestest.Exported) { } } -func (l definitions) testTypeDefinitions(t *testing.T, e *packagestest.Exported) { - if len(l) != expectedTypeDefinitionsCount { - t.Errorf("got %v definitions expected %v", len(l), expectedTypeDefinitionsCount) - } - //TODO: add command line type definition tests when it works -} - type pattern struct { raw string expanded []string diff --git a/internal/lsp/lsp_test.go b/internal/lsp/lsp_test.go index 44b39945..b9d0c276 100644 --- a/internal/lsp/lsp_test.go +++ b/internal/lsp/lsp_test.go @@ -8,13 +8,7 @@ import ( "bytes" "context" "fmt" - "go/ast" - "go/parser" "go/token" - "io/ioutil" - "os/exec" - "path/filepath" - "runtime" "sort" "strings" "testing" @@ -24,6 +18,7 @@ import ( "golang.org/x/tools/internal/lsp/diff" "golang.org/x/tools/internal/lsp/protocol" "golang.org/x/tools/internal/lsp/source" + "golang.org/x/tools/internal/lsp/tests" "golang.org/x/tools/internal/lsp/xlog" "golang.org/x/tools/internal/span" ) @@ -32,246 +27,47 @@ func TestLSP(t *testing.T) { packagestest.TestAll(t, testLSP) } +type runner struct { + server *Server + data *tests.Data +} + func testLSP(t *testing.T, exporter packagestest.Exporter) { ctx := context.Background() - const dir = "testdata" - // We hardcode the expected number of test cases to ensure that all tests - // are being executed. If a test is added, this number must be changed. - const expectedCompletionsCount = 65 - const expectedDiagnosticsCount = 16 - const expectedFormatCount = 4 - const expectedDefinitionsCount = 19 - const expectedTypeDefinitionsCount = 2 - const expectedHighlightsCount = 2 - const expectedSymbolsCount = 1 - const expectedSignaturesCount = 19 - - files := packagestest.MustCopyFileTree(dir) - overlays := map[string][]byte{} - for fragment, operation := range files { - if trimmed := strings.TrimSuffix(fragment, ".in"); trimmed != fragment { - delete(files, fragment) - files[trimmed] = operation - } - const overlay = ".overlay" - if index := strings.Index(fragment, overlay); index >= 0 { - delete(files, fragment) - partial := fragment[:index] + fragment[index+len(overlay):] - contents, err := ioutil.ReadFile(filepath.Join(dir, fragment)) - if err != nil { - t.Fatal(err) - } - overlays[partial] = contents - } - } - modules := []packagestest.Module{ - { - Name: "golang.org/x/tools/internal/lsp", - Files: files, - Overlay: overlays, - }, - } - exported := packagestest.Export(t, exporter, modules) - defer exported.Cleanup() - - // Merge the exported.Config with the view.Config. - cfg := *exported.Config - - cfg.Fset = token.NewFileSet() - cfg.Context = context.Background() - cfg.ParseFile = func(fset *token.FileSet, filename string, src []byte) (*ast.File, error) { - return parser.ParseFile(fset, filename, src, parser.AllErrors|parser.ParseComments) - } + data := tests.Load(t, exporter, "testdata") + defer data.Exported.Cleanup() log := xlog.New(xlog.StdSink{}) - s := &Server{ - views: []*cache.View{cache.NewView(ctx, log, "lsp_test", span.FileURI(cfg.Dir), &cfg)}, - undelivered: make(map[span.URI][]source.Diagnostic), - log: log, - } - // Do a first pass to collect special markers for completion. - if err := exported.Expect(map[string]interface{}{ - "item": func(name string, r packagestest.Range, _, _ string) { - exported.Mark(name, r) + r := &runner{ + server: &Server{ + views: []*cache.View{cache.NewView(ctx, log, "lsp_test", span.FileURI(data.Config.Dir), &data.Config)}, + undelivered: make(map[span.URI][]source.Diagnostic), + log: log, }, - }); err != nil { - t.Fatal(err) + data: data, } - - expectedDiagnostics := make(diagnostics) - completionItems := make(completionItems) - expectedCompletions := make(completions) - expectedFormat := make(formats) - expectedDefinitions := make(definitions) - expectedTypeDefinitions := make(definitions) - expectedHighlights := make(highlights) - expectedSymbols := &symbols{ - m: make(map[span.URI][]protocol.DocumentSymbol), - children: make(map[string][]protocol.DocumentSymbol), - } - expectedSignatures := make(signatures) - - // Collect any data that needs to be used by subsequent tests. - if err := exported.Expect(map[string]interface{}{ - "diag": expectedDiagnostics.collect, - "item": completionItems.collect, - "complete": expectedCompletions.collect, - "format": expectedFormat.collect, - "godef": expectedDefinitions.collect, - "typdef": expectedTypeDefinitions.collect, - "highlight": expectedHighlights.collect, - "symbol": expectedSymbols.collect, - "signature": expectedSignatures.collect, - }); err != nil { - t.Fatal(err) - } - - t.Run("Completion", func(t *testing.T) { - t.Helper() - if len(expectedCompletions) != expectedCompletionsCount { - t.Errorf("got %v completions expected %v", len(expectedCompletions), expectedCompletionsCount) - } - expectedCompletions.test(t, exported, s, completionItems) - }) - - t.Run("Diagnostics", func(t *testing.T) { - t.Helper() - diagnosticsCount := expectedDiagnostics.test(t, s.views[0]) - if diagnosticsCount != expectedDiagnosticsCount { - t.Errorf("got %v diagnostics expected %v", diagnosticsCount, expectedDiagnosticsCount) - } - }) - - t.Run("Format", func(t *testing.T) { - if _, err := exec.LookPath("gofmt"); err != nil { - switch runtime.GOOS { - case "android": - t.Skip("gofmt is not installed") - default: - t.Fatal(err) - } - } - t.Helper() - if len(expectedFormat) != expectedFormatCount { - t.Errorf("got %v formats expected %v", len(expectedFormat), expectedFormatCount) - } - expectedFormat.test(t, s) - }) - - t.Run("Definitions", func(t *testing.T) { - t.Helper() - if len(expectedDefinitions) != expectedDefinitionsCount { - t.Errorf("got %v definitions expected %v", len(expectedDefinitions), expectedDefinitionsCount) - } - expectedDefinitions.test(t, s, false) - }) - - t.Run("TypeDefinitions", func(t *testing.T) { - t.Helper() - if len(expectedTypeDefinitions) != expectedTypeDefinitionsCount { - t.Errorf("got %v type definitions expected %v", len(expectedTypeDefinitions), expectedTypeDefinitionsCount) - } - expectedTypeDefinitions.test(t, s, true) - }) - - t.Run("Highlights", func(t *testing.T) { - t.Helper() - if len(expectedHighlights) != expectedHighlightsCount { - t.Errorf("got %v highlights expected %v", len(expectedHighlights), expectedHighlightsCount) - } - expectedHighlights.test(t, s) - }) - - t.Run("Symbols", func(t *testing.T) { - t.Helper() - if len(expectedSymbols.m) != expectedSymbolsCount { - t.Errorf("got %v symbols expected %v", len(expectedSymbols.m), expectedSymbolsCount) - } - expectedSymbols.test(t, s) - }) - - t.Run("Signatures", func(t *testing.T) { - t.Helper() - if len(expectedSignatures) != expectedSignaturesCount { - t.Errorf("got %v signatures expected %v", len(expectedSignatures), expectedSignaturesCount) - } - expectedSignatures.test(t, s) - }) + tests.Run(t, r, data) } -type diagnostics map[span.URI][]protocol.Diagnostic -type completionItems map[token.Pos]*protocol.CompletionItem -type completions map[token.Position][]token.Pos -type formats map[string]string -type definitions map[protocol.Location]protocol.Location -type highlights map[string][]protocol.Location -type symbols struct { - m map[span.URI][]protocol.DocumentSymbol - children map[string][]protocol.DocumentSymbol -} -type signatures map[token.Position]*protocol.SignatureHelp - -func (d diagnostics) test(t *testing.T, v source.View) int { - count := 0 - ctx := context.Background() - for uri, want := range d { - sourceDiagnostics, err := source.Diagnostics(context.Background(), v, uri) - if err != nil { - t.Fatal(err) - } - got, err := toProtocolDiagnostics(ctx, v, sourceDiagnostics[uri]) +func (r *runner) Diagnostics(t *testing.T, data tests.Diagnostics) { + v := r.server.views[0] + for uri, want := range data { + results, err := source.Diagnostics(context.Background(), v, uri) if err != nil { t.Fatal(err) } + got := results[uri] if diff := diffDiagnostics(uri, want, got); diff != "" { t.Error(diff) } - count += len(want) } - return count } -func (d diagnostics) collect(e *packagestest.Exported, fset *token.FileSet, rng packagestest.Range, msgSource, msg string) { - spn, m := testLocation(e, fset, rng) - if _, ok := d[spn.URI()]; !ok { - d[spn.URI()] = []protocol.Diagnostic{} - } - // If a file has an empty diagnostic message, return. This allows us to - // avoid testing diagnostics in files that may have a lot of them. - if msg == "" { - return - } - severity := protocol.SeverityError - if strings.Contains(string(spn.URI()), "analyzer") { - severity = protocol.SeverityWarning - } - dRng, err := m.Range(spn) - if err != nil { - return - } - want := protocol.Diagnostic{ - Range: dRng, - Severity: severity, - Source: msgSource, - Message: msg, - } - d[spn.URI()] = append(d[spn.URI()], want) -} - -func sortDiagnostics(d []protocol.Diagnostic) { +func sortDiagnostics(d []source.Diagnostic) { sort.Slice(d, func(i int, j int) bool { - if d[i].Range.Start.Line < d[j].Range.Start.Line { - return true - } - if d[i].Range.Start.Line > d[j].Range.Start.Line { - return false - } - if d[i].Range.Start.Character < d[j].Range.Start.Character { - return true - } - if d[i].Range.Start.Character > d[j].Range.Start.Character { - return false + if r := span.Compare(d[i].Span, d[j].Span); r != 0 { + return r < 0 } return d[i].Message < d[j].Message }) @@ -279,7 +75,7 @@ func sortDiagnostics(d []protocol.Diagnostic) { // diffDiagnostics prints the diff between expected and actual diagnostics test // results. -func diffDiagnostics(uri span.URI, want, got []protocol.Diagnostic) string { +func diffDiagnostics(uri span.URI, want, got []source.Diagnostic) string { sortDiagnostics(want) sortDiagnostics(got) if len(got) != len(want) { @@ -290,17 +86,17 @@ func diffDiagnostics(uri span.URI, want, got []protocol.Diagnostic) string { if w.Message != g.Message { return summarizeDiagnostics(i, want, got, "incorrect Message got %v want %v", g.Message, w.Message) } - if w.Range.Start != g.Range.Start { - return summarizeDiagnostics(i, want, got, "incorrect Range.Start got %v want %v", g.Range.Start, w.Range.Start) + if span.ComparePoint(w.Start(), g.Start()) != 0 { + return summarizeDiagnostics(i, want, got, "incorrect Start got %v want %v", g.Start(), w.Start()) } // Special case for diagnostics on parse errors. if strings.Contains(string(uri), "noparse") { - if g.Range.Start != g.Range.End || w.Range.Start != g.Range.End { - return summarizeDiagnostics(i, want, got, "incorrect Range.End got %v want %v", g.Range.End, w.Range.Start) + if span.ComparePoint(g.Start(), g.End()) != 0 || span.ComparePoint(w.Start(), g.End()) != 0 { + return summarizeDiagnostics(i, want, got, "incorrect End got %v want %v", g.End(), w.Start()) } - } else if g.Range.End != g.Range.Start { // Accept any 'want' range if the diagnostic returns a zero-length range. - if w.Range.End != g.Range.End { - return summarizeDiagnostics(i, want, got, "incorrect Range.End got %v want %v", g.Range.End, w.Range.End) + } else if !g.IsPoint() { // Accept any 'want' range if the diagnostic returns a zero-length range. + if span.ComparePoint(w.End(), g.End()) != 0 { + return summarizeDiagnostics(i, want, got, "incorrect End got %v want %v", g.End(), w.End()) } } if w.Severity != g.Severity { @@ -313,7 +109,7 @@ func diffDiagnostics(uri span.URI, want, got []protocol.Diagnostic) string { return "" } -func summarizeDiagnostics(i int, want []protocol.Diagnostic, got []protocol.Diagnostic, reason string, args ...interface{}) string { +func summarizeDiagnostics(i int, want []source.Diagnostic, got []source.Diagnostic, reason string, args ...interface{}) string { msg := &bytes.Buffer{} fmt.Fprint(msg, "diagnostics failed") if i >= 0 { @@ -332,27 +128,27 @@ func summarizeDiagnostics(i int, want []protocol.Diagnostic, got []protocol.Diag return msg.String() } -func (c completions) test(t *testing.T, exported *packagestest.Exported, s *Server, items completionItems) { - for src, itemList := range c { - var want []protocol.CompletionItem +func (r *runner) Completion(t *testing.T, data tests.Completions, items tests.CompletionItems) { + for src, itemList := range data { + var want []source.CompletionItem for _, pos := range itemList { want = append(want, *items[pos]) } - list, err := s.Completion(context.Background(), &protocol.CompletionParams{ + list, err := r.server.Completion(context.Background(), &protocol.CompletionParams{ TextDocumentPositionParams: protocol.TextDocumentPositionParams{ TextDocument: protocol.TextDocumentIdentifier{ - URI: protocol.NewURI(span.FileURI(src.Filename)), + URI: protocol.NewURI(src.URI()), }, Position: protocol.Position{ - Line: float64(src.Line - 1), - Character: float64(src.Column - 1), + Line: float64(src.Start().Line() - 1), + Character: float64(src.Start().Column() - 1), }, }, }) if err != nil { t.Fatal(err) } - wantBuiltins := strings.Contains(src.Filename, "builtins") + wantBuiltins := strings.Contains(string(src.URI()), "builtins") var got []protocol.CompletionItem for _, item := range list.Items { if !wantBuiltins && isBuiltin(item) { @@ -361,7 +157,7 @@ func (c completions) test(t *testing.T, exported *packagestest.Exported, s *Serv got = append(got, item) } if err != nil { - t.Fatalf("completion failed for %s:%v:%v: %v", filepath.Base(src.Filename), src.Line, src.Column, err) + t.Fatalf("completion failed for %v: %v", src, err) } if diff := diffCompletionItems(t, src, want, got); diff != "" { t.Errorf(diff) @@ -388,21 +184,9 @@ func isBuiltin(item protocol.CompletionItem) bool { return false } -func (c completions) collect(src token.Position, expected []token.Pos) { - c[src] = expected -} - -func (i completionItems) collect(pos token.Pos, label, detail, kind string) { - i[pos] = &protocol.CompletionItem{ - Label: label, - Detail: detail, - Kind: protocol.ParseCompletionItemKind(kind), - } -} - // diffCompletionItems prints the diff between expected and actual completion // test results. -func diffCompletionItems(t *testing.T, pos token.Position, want, got []protocol.CompletionItem) string { +func diffCompletionItems(t *testing.T, spn span.Span, want []source.CompletionItem, got []protocol.CompletionItem) string { if len(got) != len(want) { return summarizeCompletionItems(-1, want, got, "different lengths got %v want %v", len(got), len(want)) } @@ -414,14 +198,14 @@ func diffCompletionItems(t *testing.T, pos token.Position, want, got []protocol. if w.Detail != g.Detail { return summarizeCompletionItems(i, want, got, "incorrect Detail got %v want %v", g.Detail, w.Detail) } - if w.Kind != g.Kind { - return summarizeCompletionItems(i, want, got, "incorrect Kind got %v want %v", g.Kind, w.Kind) + if wkind := toProtocolCompletionItemKind(w.Kind); wkind != g.Kind { + return summarizeCompletionItems(i, want, got, "incorrect Kind got %v want %v", g.Kind, wkind) } } return "" } -func summarizeCompletionItems(i int, want []protocol.CompletionItem, got []protocol.CompletionItem, reason string, args ...interface{}) string { +func summarizeCompletionItems(i int, want []source.CompletionItem, got []protocol.CompletionItem, reason string, args ...interface{}) string { msg := &bytes.Buffer{} fmt.Fprint(msg, "completion failed") if i >= 0 { @@ -440,11 +224,11 @@ func summarizeCompletionItems(i int, want []protocol.CompletionItem, got []proto return msg.String() } -func (f formats) test(t *testing.T, s *Server) { +func (r *runner) Format(t *testing.T, data tests.Formats) { ctx := context.Background() - for filename, gofmted := range f { + for filename, gofmted := range data { uri := span.FileURI(filename) - edits, err := s.Formatting(context.Background(), &protocol.DocumentFormattingParams{ + edits, err := r.server.Formatting(context.Background(), &protocol.DocumentFormattingParams{ TextDocument: protocol.TextDocumentIdentifier{ URI: protocol.NewURI(uri), }, @@ -455,7 +239,7 @@ func (f formats) test(t *testing.T, s *Server) { } continue } - _, m, err := newColumnMap(ctx, s.findView(ctx, uri), uri) + _, m, err := newColumnMap(ctx, r.server.findView(ctx, uri), uri) if err != nil { t.Error(err) } @@ -471,74 +255,51 @@ func (f formats) test(t *testing.T, s *Server) { } } -func (f formats) collect(pos token.Position) { - cmd := exec.Command("gofmt", pos.Filename) - stdout := bytes.NewBuffer(nil) - cmd.Stdout = stdout - cmd.Run() // ignore error, sometimes we have intentionally ungofmt-able files - f[pos.Filename] = stdout.String() -} - -func (d definitions) test(t *testing.T, s *Server, typ bool) { - for src, target := range d { +func (r *runner) Definition(t *testing.T, data tests.Definitions) { + for _, d := range data { + sm := r.mapper(d.Src.URI()) + loc, err := sm.Location(d.Src) + if err != nil { + t.Fatalf("failed for %v: %v", d.Src, err) + } params := &protocol.TextDocumentPositionParams{ - TextDocument: protocol.TextDocumentIdentifier{ - URI: src.URI, - }, - Position: src.Range.Start, + TextDocument: protocol.TextDocumentIdentifier{URI: loc.URI}, + Position: loc.Range.Start, } var locs []protocol.Location - var err error - if typ { - locs, err = s.TypeDefinition(context.Background(), params) + if d.IsType { + locs, err = r.server.TypeDefinition(context.Background(), params) } else { - locs, err = s.Definition(context.Background(), params) + locs, err = r.server.Definition(context.Background(), params) } if err != nil { - t.Fatalf("failed for %v: %v", src, err) + t.Fatalf("failed for %v: %v", d.Src, err) } if len(locs) != 1 { t.Errorf("got %d locations for definition, expected 1", len(locs)) } - if locs[0] != target { - t.Errorf("for %v got %v want %v", src, locs[0], target) + locURI := span.NewURI(locs[0].URI) + lm := r.mapper(locURI) + if def, err := lm.Span(locs[0]); err != nil { + t.Fatalf("failed for %v: %v", locs[0], err) + } else if def != d.Def { + t.Errorf("for %v got %v want %v", d.Src, def, d.Def) } } } -func (d definitions) collect(e *packagestest.Exported, fset *token.FileSet, src, target packagestest.Range) { - sSrc, mSrc := testLocation(e, fset, src) - lSrc, err := mSrc.Location(sSrc) - if err != nil { - return - } - sTarget, mTarget := testLocation(e, fset, target) - lTarget, err := mTarget.Location(sTarget) - if err != nil { - return - } - d[lSrc] = lTarget -} - -func (h highlights) collect(e *packagestest.Exported, fset *token.FileSet, name string, rng packagestest.Range) { - s, m := testLocation(e, fset, rng) - loc, err := m.Location(s) - if err != nil { - return - } - - h[name] = append(h[name], loc) -} - -func (h highlights) test(t *testing.T, s *Server) { - for name, locations := range h { +func (r *runner) Highlight(t *testing.T, data tests.Highlights) { + for name, locations := range data { + m := r.mapper(locations[0].URI()) + loc, err := m.Location(locations[0]) + if err != nil { + t.Fatalf("failed for %v: %v", locations[0], err) + } params := &protocol.TextDocumentPositionParams{ - TextDocument: protocol.TextDocumentIdentifier{ - URI: locations[0].URI, - }, - Position: locations[0].Range.Start, + TextDocument: protocol.TextDocumentIdentifier{URI: loc.URI}, + Position: loc.Range.Start, } - highlights, err := s.DocumentHighlight(context.Background(), params) + highlights, err := r.server.DocumentHighlight(context.Background(), params) if err != nil { t.Fatal(err) } @@ -546,55 +307,23 @@ func (h highlights) test(t *testing.T, s *Server) { t.Fatalf("got %d highlights for %s, expected %d", len(highlights), name, len(locations)) } for i := range highlights { - if highlights[i].Range != locations[i].Range { - t.Errorf("want %v, got %v\n", locations[i].Range, highlights[i].Range) + if h, err := m.RangeSpan(highlights[i].Range); err != nil { + t.Fatalf("failed for %v: %v", highlights[i], err) + } else if h != locations[i] { + t.Errorf("want %v, got %v\n", locations[i], h) } } } } -func (s symbols) collect(e *packagestest.Exported, fset *token.FileSet, name string, rng span.Range, kind string, parentName string) { - f := fset.File(rng.Start) - if f == nil { - return - } - - content, err := e.FileContents(f.Name()) - if err != nil { - return - } - - spn, err := rng.Span() - if err != nil { - return - } - - m := protocol.NewColumnMapper(spn.URI(), fset, f, content) - prng, err := m.Range(spn) - if err != nil { - return - } - - sym := protocol.DocumentSymbol{ - Name: name, - Kind: protocol.ParseSymbolKind(kind), - SelectionRange: prng, - } - if parentName == "" { - s.m[spn.URI()] = append(s.m[spn.URI()], sym) - } else { - s.children[parentName] = append(s.children[parentName], sym) - } -} - -func (s symbols) test(t *testing.T, server *Server) { - for uri, expectedSymbols := range s.m { +func (r *runner) Symbol(t *testing.T, data tests.Symbols) { + for uri, expectedSymbols := range data { params := &protocol.DocumentSymbolParams{ TextDocument: protocol.TextDocumentIdentifier{ URI: string(uri), }, } - symbols, err := server.DocumentSymbol(context.Background(), params) + symbols, err := r.server.DocumentSymbol(context.Background(), params) if err != nil { t.Fatal(err) } @@ -603,86 +332,16 @@ func (s symbols) test(t *testing.T, server *Server) { t.Errorf("want %d top-level symbols in %v, got %d", len(expectedSymbols), uri, len(symbols)) continue } - - for i := range expectedSymbols { - children := s.children[expectedSymbols[i].Name] - expectedSymbols[i].Children = children - } - if diff := diffSymbols(uri, expectedSymbols, symbols); diff != "" { + if diff := r.diffSymbols(uri, expectedSymbols, symbols); diff != "" { t.Error(diff) } } } -func (s signatures) collect(src token.Position, signature string, activeParam int64) { - s[src] = &protocol.SignatureHelp{ - Signatures: []protocol.SignatureInformation{{Label: signature}}, - ActiveSignature: 0, - ActiveParameter: float64(activeParam), - } -} - -func diffSignatures(src token.Position, want, got *protocol.SignatureHelp) string { - decorate := func(f string, args ...interface{}) string { - return fmt.Sprintf("Invalid signature at %s: %s", src, fmt.Sprintf(f, args...)) - } - - if lw, lg := len(want.Signatures), len(got.Signatures); lw != lg { - return decorate("wanted %d signatures, got %d", lw, lg) - } - - if want.ActiveSignature != got.ActiveSignature { - return decorate("wanted active signature of %f, got %f", want.ActiveSignature, got.ActiveSignature) - } - - if want.ActiveParameter != got.ActiveParameter { - return decorate("wanted active parameter of %f, got %f", want.ActiveParameter, got.ActiveParameter) - } - - for i := range want.Signatures { - wantSig, gotSig := want.Signatures[i], got.Signatures[i] - - if wantSig.Label != gotSig.Label { - return decorate("wanted label %q, got %q", wantSig.Label, gotSig.Label) - } - - var paramParts []string - for _, p := range gotSig.Parameters { - paramParts = append(paramParts, p.Label) - } - paramsStr := strings.Join(paramParts, ", ") - if !strings.Contains(gotSig.Label, paramsStr) { - return decorate("expected signature %q to contain params %q", gotSig.Label, paramsStr) - } - } - - return "" -} - -func (s signatures) test(t *testing.T, server *Server) { - for src, expectedSignatures := range s { - gotSignatures, err := server.SignatureHelp(context.Background(), &protocol.TextDocumentPositionParams{ - TextDocument: protocol.TextDocumentIdentifier{ - URI: protocol.NewURI(span.FileURI(src.Filename)), - }, - Position: protocol.Position{ - Line: float64(src.Line - 1), - Character: float64(src.Column - 1), - }, - }) - if err != nil { - t.Fatal(err) - } - - if diff := diffSignatures(src, expectedSignatures, gotSignatures); diff != "" { - t.Error(diff) - } - } -} - -func diffSymbols(uri span.URI, want, got []protocol.DocumentSymbol) string { +func (r *runner) diffSymbols(uri span.URI, want []source.Symbol, got []protocol.DocumentSymbol) string { sort.Slice(want, func(i, j int) bool { return want[i].Name < want[j].Name }) sort.Slice(got, func(i, j int) bool { return got[i].Name < got[j].Name }) + m := r.mapper(uri) if len(got) != len(want) { return summarizeSymbols(-1, want, got, "different lengths got %v want %v", len(got), len(want)) } @@ -691,20 +350,24 @@ func diffSymbols(uri span.URI, want, got []protocol.DocumentSymbol) string { if w.Name != g.Name { return summarizeSymbols(i, want, got, "incorrect name got %v want %v", g.Name, w.Name) } - if w.Kind != g.Kind { - return summarizeSymbols(i, want, got, "incorrect kind got %v want %v", g.Kind, w.Kind) + if wkind := toProtocolSymbolKind(w.Kind); wkind != g.Kind { + return summarizeSymbols(i, want, got, "incorrect kind got %v want %v", g.Kind, wkind) } - if w.SelectionRange != g.SelectionRange { - return summarizeSymbols(i, want, got, "incorrect span got %v want %v", g.SelectionRange, w.SelectionRange) + spn, err := m.RangeSpan(g.SelectionRange) + if err != nil { + return summarizeSymbols(i, want, got, "%v", err) } - if msg := diffSymbols(uri, w.Children, g.Children); msg != "" { + if w.SelectionSpan != spn { + return summarizeSymbols(i, want, got, "incorrect span got %v want %v", spn, w.SelectionSpan) + } + if msg := r.diffSymbols(uri, w.Children, g.Children); msg != "" { return fmt.Sprintf("children of %s: %s", w.Name, msg) } } return "" } -func summarizeSymbols(i int, want []protocol.DocumentSymbol, got []protocol.DocumentSymbol, reason string, args ...interface{}) string { +func summarizeSymbols(i int, want []source.Symbol, got []protocol.DocumentSymbol, reason string, args ...interface{}) string { msg := &bytes.Buffer{} fmt.Fprint(msg, "document symbols failed") if i >= 0 { @@ -714,7 +377,7 @@ func summarizeSymbols(i int, want []protocol.DocumentSymbol, got []protocol.Docu fmt.Fprintf(msg, reason, args...) fmt.Fprint(msg, ":\nexpected:\n") for _, s := range want { - fmt.Fprintf(msg, " %v %v %v\n", s.Name, s.Kind, s.SelectionRange) + fmt.Fprintf(msg, " %v %v %v\n", s.Name, s.Kind, s.SelectionSpan) } fmt.Fprintf(msg, "got:\n") for _, s := range got { @@ -723,18 +386,86 @@ func summarizeSymbols(i int, want []protocol.DocumentSymbol, got []protocol.Docu return msg.String() } -func testLocation(e *packagestest.Exported, fset *token.FileSet, rng packagestest.Range) (span.Span, *protocol.ColumnMapper) { - spn, err := span.NewRange(fset, rng.Start, rng.End).Span() - if err != nil { - return spn, nil +func (r *runner) Signature(t *testing.T, data tests.Signatures) { + for spn, expectedSignatures := range data { + m := r.mapper(spn.URI()) + loc, err := m.Location(spn) + if err != nil { + t.Fatalf("failed for %v: %v", loc, err) + } + gotSignatures, err := r.server.SignatureHelp(context.Background(), &protocol.TextDocumentPositionParams{ + TextDocument: protocol.TextDocumentIdentifier{ + URI: protocol.NewURI(spn.URI()), + }, + Position: loc.Range.Start, + }) + if err != nil { + t.Fatal(err) + } + + if diff := diffSignatures(spn, expectedSignatures, gotSignatures); diff != "" { + t.Error(diff) + } } - f := fset.File(rng.Start) - content, err := e.FileContents(f.Name()) - if err != nil { - return spn, nil +} + +func diffSignatures(spn span.Span, want source.SignatureInformation, got *protocol.SignatureHelp) string { + decorate := func(f string, args ...interface{}) string { + return fmt.Sprintf("Invalid signature at %s: %s", spn, fmt.Sprintf(f, args...)) } - m := protocol.NewColumnMapper(spn.URI(), fset, f, content) - return spn, m + + if len(got.Signatures) != 1 { + return decorate("wanted 1 signature, got %d", len(got.Signatures)) + } + + if got.ActiveSignature != 0 { + return decorate("wanted active signature of 0, got %f", got.ActiveSignature) + } + + if want.ActiveParameter != int(got.ActiveParameter) { + return decorate("wanted active parameter of %d, got %f", want.ActiveParameter, got.ActiveParameter) + } + + gotSig := got.Signatures[int(got.ActiveSignature)] + + if want.Label != gotSig.Label { + return decorate("wanted label %q, got %q", want.Label, gotSig.Label) + } + + var paramParts []string + for _, p := range gotSig.Parameters { + paramParts = append(paramParts, p.Label) + } + paramsStr := strings.Join(paramParts, ", ") + if !strings.Contains(gotSig.Label, paramsStr) { + return decorate("expected signature %q to contain params %q", gotSig.Label, paramsStr) + } + + return "" +} + +func (r *runner) mapper(uri span.URI) *protocol.ColumnMapper { + fname, err := uri.Filename() + if err != nil { + return nil + } + fset := r.data.Exported.ExpectFileSet + var f *token.File + fset.Iterate(func(check *token.File) bool { + if check.Name() == fname { + f = check + return false + } + return true + }) + if f == nil { + return nil + } + content, err := r.data.Exported.FileContents(f.Name()) + if err != nil { + return nil + } + return protocol.NewColumnMapper(uri, fset, f, content) } func TestBytesOffset(t *testing.T) { diff --git a/internal/lsp/tests/tests.go b/internal/lsp/tests/tests.go new file mode 100644 index 00000000..10012c84 --- /dev/null +++ b/internal/lsp/tests/tests.go @@ -0,0 +1,313 @@ +// Copyright 2019q The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tests + +import ( + "bytes" + "context" + "go/ast" + "go/parser" + "go/token" + "io/ioutil" + "os/exec" + "path/filepath" + "runtime" + "strings" + "testing" + + "golang.org/x/tools/go/packages" + "golang.org/x/tools/go/packages/packagestest" + "golang.org/x/tools/internal/lsp/source" + "golang.org/x/tools/internal/span" +) + +// We hardcode the expected number of test cases to ensure that all tests +// are being executed. If a test is added, this number must be changed. +const ( + ExpectedCompletionsCount = 65 + ExpectedDiagnosticsCount = 16 + ExpectedFormatCount = 4 + ExpectedDefinitionsCount = 21 + ExpectedTypeDefinitionsCount = 2 + ExpectedHighlightsCount = 2 + ExpectedSymbolsCount = 1 + ExpectedSignaturesCount = 19 +) + +type Diagnostics map[span.URI][]source.Diagnostic +type CompletionItems map[token.Pos]*source.CompletionItem +type Completions map[span.Span][]token.Pos +type Formats map[string]string +type Definitions map[span.Span]Definition +type Highlights map[string][]span.Span +type Symbols map[span.URI][]source.Symbol +type SymbolsChildren map[string][]source.Symbol +type Signatures map[span.Span]source.SignatureInformation + +type Data struct { + Config packages.Config + Exported *packagestest.Exported + Diagnostics Diagnostics + CompletionItems CompletionItems + Completions Completions + Formats Formats + Definitions Definitions + Highlights Highlights + Symbols Symbols + symbolsChildren SymbolsChildren + Signatures Signatures +} + +type Tests interface { + Diagnostics(*testing.T, Diagnostics) + Completion(*testing.T, Completions, CompletionItems) + Format(*testing.T, Formats) + Definition(*testing.T, Definitions) + Highlight(*testing.T, Highlights) + Symbol(*testing.T, Symbols) + Signature(*testing.T, Signatures) +} + +type Definition struct { + Src span.Span + IsType bool + Flags string + Def span.Span + Match string +} + +func Load(t testing.TB, exporter packagestest.Exporter, dir string) *Data { + t.Helper() + + data := &Data{ + Diagnostics: make(Diagnostics), + CompletionItems: make(CompletionItems), + Completions: make(Completions), + Formats: make(Formats), + Definitions: make(Definitions), + Highlights: make(Highlights), + Symbols: make(Symbols), + symbolsChildren: make(SymbolsChildren), + Signatures: make(Signatures), + } + + files := packagestest.MustCopyFileTree(dir) + overlays := map[string][]byte{} + for fragment, operation := range files { + if trimmed := strings.TrimSuffix(fragment, ".in"); trimmed != fragment { + delete(files, fragment) + files[trimmed] = operation + } + const overlay = ".overlay" + if index := strings.Index(fragment, overlay); index >= 0 { + delete(files, fragment) + partial := fragment[:index] + fragment[index+len(overlay):] + contents, err := ioutil.ReadFile(filepath.Join(dir, fragment)) + if err != nil { + t.Fatal(err) + } + overlays[partial] = contents + } + } + modules := []packagestest.Module{ + { + Name: "golang.org/x/tools/internal/lsp", + Files: files, + Overlay: overlays, + }, + } + data.Exported = packagestest.Export(t, exporter, modules) + + // Merge the exported.Config with the view.Config. + data.Config = *data.Exported.Config + data.Config.Fset = token.NewFileSet() + data.Config.Context = context.Background() + data.Config.ParseFile = func(fset *token.FileSet, filename string, src []byte) (*ast.File, error) { + return parser.ParseFile(fset, filename, src, parser.AllErrors|parser.ParseComments) + } + + // Do a first pass to collect special markers for completion. + if err := data.Exported.Expect(map[string]interface{}{ + "item": func(name string, r packagestest.Range, _, _ string) { + data.Exported.Mark(name, r) + }, + }); err != nil { + t.Fatal(err) + } + + // Collect any data that needs to be used by subsequent tests. + if err := data.Exported.Expect(map[string]interface{}{ + "diag": data.collectDiagnostics, + "item": data.collectCompletionItems, + "complete": data.collectCompletions, + "format": data.collectFormats, + "godef": data.collectDefinitions, + "typdef": data.collectTypeDefinitions, + "highlight": data.collectHighlights, + "symbol": data.collectSymbols, + "signature": data.collectSignatures, + }); err != nil { + t.Fatal(err) + } + for _, symbols := range data.Symbols { + for i := range symbols { + children := data.symbolsChildren[symbols[i].Name] + symbols[i].Children = children + } + } + return data +} + +func Run(t *testing.T, tests Tests, data *Data) { + t.Helper() + t.Run("Completion", func(t *testing.T) { + t.Helper() + if len(data.Completions) != ExpectedCompletionsCount { + t.Errorf("got %v completions expected %v", len(data.Completions), ExpectedCompletionsCount) + } + tests.Completion(t, data.Completions, data.CompletionItems) + }) + + t.Run("Diagnostics", func(t *testing.T) { + t.Helper() + diagnosticsCount := 0 + for _, want := range data.Diagnostics { + diagnosticsCount += len(want) + } + if diagnosticsCount != ExpectedDiagnosticsCount { + t.Errorf("got %v diagnostics expected %v", diagnosticsCount, ExpectedDiagnosticsCount) + } + tests.Diagnostics(t, data.Diagnostics) + }) + + t.Run("Format", func(t *testing.T) { + t.Helper() + if _, err := exec.LookPath("gofmt"); err != nil { + switch runtime.GOOS { + case "android": + t.Skip("gofmt is not installed") + default: + t.Fatal(err) + } + } + if len(data.Formats) != ExpectedFormatCount { + t.Errorf("got %v formats expected %v", len(data.Formats), ExpectedFormatCount) + } + tests.Format(t, data.Formats) + }) + + t.Run("Definitions", func(t *testing.T) { + t.Helper() + if len(data.Definitions) != ExpectedDefinitionsCount { + t.Errorf("got %v definitions expected %v", len(data.Definitions), ExpectedDefinitionsCount) + } + tests.Definition(t, data.Definitions) + }) + + t.Run("Highlights", func(t *testing.T) { + t.Helper() + if len(data.Highlights) != ExpectedHighlightsCount { + t.Errorf("got %v highlights expected %v", len(data.Highlights), ExpectedHighlightsCount) + } + tests.Highlight(t, data.Highlights) + }) + + t.Run("Symbols", func(t *testing.T) { + t.Helper() + if len(data.Symbols) != ExpectedSymbolsCount { + t.Errorf("got %v symbols expected %v", len(data.Symbols), ExpectedSymbolsCount) + } + tests.Symbol(t, data.Symbols) + }) + + t.Run("Signatures", func(t *testing.T) { + t.Helper() + if len(data.Signatures) != ExpectedSignaturesCount { + t.Errorf("got %v signatures expected %v", len(data.Signatures), ExpectedSignaturesCount) + } + tests.Signature(t, data.Signatures) + }) +} + +func (data *Data) collectDiagnostics(spn span.Span, msgSource, msg string) { + if _, ok := data.Diagnostics[spn.URI()]; !ok { + data.Diagnostics[spn.URI()] = []source.Diagnostic{} + } + // If a file has an empty diagnostic message, return. This allows us to + // avoid testing diagnostics in files that may have a lot of them. + if msg == "" { + return + } + severity := source.SeverityError + if strings.Contains(string(spn.URI()), "analyzer") { + severity = source.SeverityWarning + } + want := source.Diagnostic{ + Span: spn, + Severity: severity, + Source: msgSource, + Message: msg, + } + data.Diagnostics[spn.URI()] = append(data.Diagnostics[spn.URI()], want) +} + +func (data *Data) collectCompletions(src span.Span, expected []token.Pos) { + data.Completions[src] = expected +} + +func (data *Data) collectCompletionItems(pos token.Pos, label, detail, kind string) { + data.CompletionItems[pos] = &source.CompletionItem{ + Label: label, + Detail: detail, + Kind: source.ParseCompletionItemKind(kind), + } +} + +func (data *Data) collectFormats(pos token.Position) { + cmd := exec.Command("gofmt", pos.Filename) + stdout := bytes.NewBuffer(nil) + cmd.Stdout = stdout + cmd.Run() // ignore error, sometimes we have intentionally ungofmt-able files + data.Formats[pos.Filename] = stdout.String() +} + +func (data *Data) collectDefinitions(src, target span.Span) { + data.Definitions[src] = Definition{ + Src: src, + Def: target, + } +} + +func (data *Data) collectTypeDefinitions(src, target span.Span) { + data.Definitions[src] = Definition{ + Src: src, + Def: target, + IsType: true, + } +} + +func (data *Data) collectHighlights(name string, rng span.Span) { + data.Highlights[name] = append(data.Highlights[name], rng) +} + +func (data *Data) collectSymbols(name string, spn span.Span, kind string, parentName string) { + sym := source.Symbol{ + Name: name, + Kind: source.ParseSymbolKind(kind), + SelectionSpan: spn, + } + if parentName == "" { + data.Symbols[spn.URI()] = append(data.Symbols[spn.URI()], sym) + } else { + data.symbolsChildren[parentName] = append(data.symbolsChildren[parentName], sym) + } +} + +func (data *Data) collectSignatures(spn span.Span, signature string, activeParam int64) { + data.Signatures[spn] = source.SignatureInformation{ + Label: signature, + ActiveParameter: int(activeParam), + } +}