diff --git a/internal/lsp/cmd/cmd_test.go b/internal/lsp/cmd/cmd_test.go index 8863145b..c1a882ff 100644 --- a/internal/lsp/cmd/cmd_test.go +++ b/internal/lsp/cmd/cmd_test.go @@ -49,6 +49,11 @@ func (r *runner) Completion(t *testing.T, data tests.Completions, snippets tests func (r *runner) Highlight(t *testing.T, data tests.Highlights) { //TODO: add command line highlight tests when it works } + +func (r *runner) Reference(t *testing.T, data tests.References) { + //TODO: add command line references tests when it works +} + func (r *runner) Symbol(t *testing.T, data tests.Symbols) { //TODO: add command line symbol tests when it works } diff --git a/internal/lsp/general.go b/internal/lsp/general.go index 0964de70..92e28a53 100644 --- a/internal/lsp/general.go +++ b/internal/lsp/general.go @@ -69,6 +69,7 @@ func (s *Server) initialize(ctx context.Context, params *protocol.InitializePara HoverProvider: true, DocumentHighlightProvider: true, DocumentLinkProvider: &protocol.DocumentLinkOptions{}, + ReferencesProvider: true, SignatureHelpProvider: &protocol.SignatureHelpOptions{ TriggerCharacters: []string{"(", ","}, }, diff --git a/internal/lsp/lsp_test.go b/internal/lsp/lsp_test.go index 2548b46a..a7d12b0f 100644 --- a/internal/lsp/lsp_test.go +++ b/internal/lsp/lsp_test.go @@ -460,6 +460,48 @@ func (r *runner) Highlight(t *testing.T, data tests.Highlights) { } } +func (r *runner) Reference(t *testing.T, data tests.References) { + for src, itemList := range data { + sm, err := r.mapper(src.URI()) + if err != nil { + t.Fatal(err) + } + loc, err := sm.Location(src) + if err != nil { + t.Fatalf("failed for %v: %v", src, err) + } + + want := make(map[protocol.Location]bool) + for _, pos := range itemList { + loc, err := sm.Location(pos) + if err != nil { + t.Fatalf("failed for %v: %v", src, err) + } + want[loc] = true + } + + params := &protocol.ReferenceParams{ + TextDocumentPositionParams: protocol.TextDocumentPositionParams{ + TextDocument: protocol.TextDocumentIdentifier{URI: loc.URI}, + Position: loc.Range.Start, + }, + } + got, err := r.server.References(context.Background(), params) + if err != nil { + t.Fatalf("failed for %v: %v", src, err) + } + + if len(got) != len(itemList) { + t.Errorf("references failed: different lengths got %v want %v", len(got), len(itemList)) + } + for _, loc := range got { + if !want[loc] { + t.Errorf("references failed: incorrect references got %v want %v", got, want) + } + } + } +} + func (r *runner) Symbol(t *testing.T, data tests.Symbols) { for uri, expectedSymbols := range data { params := &protocol.DocumentSymbolParams{ diff --git a/internal/lsp/references.go b/internal/lsp/references.go new file mode 100644 index 00000000..27862441 --- /dev/null +++ b/internal/lsp/references.go @@ -0,0 +1,56 @@ +package lsp + +import ( + "context" + + "golang.org/x/tools/internal/lsp/protocol" + "golang.org/x/tools/internal/lsp/source" + "golang.org/x/tools/internal/span" +) + +func (s *Server) references(ctx context.Context, params *protocol.ReferenceParams) ([]protocol.Location, error) { + uri := span.NewURI(params.TextDocument.URI) + view := s.session.ViewOf(uri) + f, m, err := getGoFile(ctx, view, uri) + if err != nil { + return nil, err + } + spn, err := m.PointSpan(params.Position) + if err != nil { + return nil, err + } + rng, err := spn.Range(m.Converter) + if err != nil { + return nil, err + } + + // Find all references to the identifier at the position. + ident, err := source.Identifier(ctx, view, f, rng.Start) + if err != nil { + return nil, err + } + references, err := ident.References(ctx) + if err != nil { + return nil, err + } + + // Get the location of each reference to return as the result. + locations := make([]protocol.Location, 0, len(references)) + for _, ref := range references { + refSpan, err := ref.Range.Span() + if err != nil { + return nil, err + } + _, refM, err := getSourceFile(ctx, view, refSpan.URI()) + if err != nil { + return nil, err + } + loc, err := refM.Location(refSpan) + if err != nil { + return nil, err + } + + locations = append(locations, loc) + } + return locations, nil +} diff --git a/internal/lsp/server.go b/internal/lsp/server.go index 13bcaefe..121cc13e 100644 --- a/internal/lsp/server.go +++ b/internal/lsp/server.go @@ -186,8 +186,8 @@ func (s *Server) Implementation(context.Context, *protocol.TextDocumentPositionP return nil, notImplemented("Implementation") } -func (s *Server) References(context.Context, *protocol.ReferenceParams) ([]protocol.Location, error) { - return nil, notImplemented("References") +func (s *Server) References(ctx context.Context, params *protocol.ReferenceParams) ([]protocol.Location, error) { + return s.references(ctx, params) } func (s *Server) DocumentHighlight(ctx context.Context, params *protocol.TextDocumentPositionParams) ([]protocol.DocumentHighlight, error) { diff --git a/internal/lsp/source/references.go b/internal/lsp/source/references.go new file mode 100644 index 00000000..8a3ae736 --- /dev/null +++ b/internal/lsp/source/references.go @@ -0,0 +1,56 @@ +package source + +import ( + "context" + "fmt" + "go/ast" + + "golang.org/x/tools/internal/span" +) + +// ReferenceInfo holds information about reference to an identifier in Go source. +type ReferenceInfo struct { + Name string + Range span.Range + ident *ast.Ident +} + +// References returns a list of references for a given identifier within a package. +func (i *IdentifierInfo) References(ctx context.Context) ([]*ReferenceInfo, error) { + pkg := i.File.GetPackage(ctx) + if pkg == nil || pkg.IsIllTyped() { + return nil, fmt.Errorf("package for %s is ill typed", i.File.URI()) + } + pkgInfo := pkg.GetTypesInfo() + if pkgInfo == nil { + return nil, fmt.Errorf("package %s has no types info", pkg.PkgPath()) + } + + // If the object declaration is nil, assume it is an import spec and do not look for references. + declObj := i.decl.obj + if declObj == nil { + return []*ReferenceInfo{}, nil + } + + var references []*ReferenceInfo + for ident, obj := range pkgInfo.Defs { + if obj == declObj { + references = append(references, &ReferenceInfo{ + Name: ident.Name, + Range: span.NewRange(i.File.FileSet(), ident.Pos(), ident.End()), + ident: ident, + }) + } + } + for ident, obj := range pkgInfo.Uses { + if obj == declObj { + references = append(references, &ReferenceInfo{ + Name: ident.Name, + Range: span.NewRange(i.File.FileSet(), ident.Pos(), ident.End()), + ident: ident, + }) + } + } + + return references, nil +} diff --git a/internal/lsp/source/source_test.go b/internal/lsp/source/source_test.go index 8135e1db..be9c0f6c 100644 --- a/internal/lsp/source/source_test.go +++ b/internal/lsp/source/source_test.go @@ -417,6 +417,46 @@ func (r *runner) Highlight(t *testing.T, data tests.Highlights) { } } +func (r *runner) Reference(t *testing.T, data tests.References) { + ctx := context.Background() + for src, itemList := range data { + f, err := r.view.GetFile(ctx, src.URI()) + if err != nil { + t.Fatalf("failed for %v: %v", src, err) + } + + tok := f.GetToken(ctx) + pos := tok.Pos(src.Start().Offset()) + ident, err := source.Identifier(ctx, r.view, f.(source.GoFile), pos) + if err != nil { + t.Fatalf("failed for %v: %v", src, err) + } + + want := make(map[span.Span]bool) + for _, pos := range itemList { + want[pos] = true + } + + got, err := ident.References(ctx) + if err != nil { + t.Fatalf("failed for %v: %v", src, err) + } + + if len(got) != len(itemList) { + t.Errorf("references failed: different lengths got %v want %v", len(got), len(itemList)) + } + for _, refInfo := range got { + refSpan, err := refInfo.Range.Span() + if err != nil { + t.Errorf("failed for %v item %v: %v", src, refInfo.Name, err) + } + if !want[refSpan] { + t.Errorf("references failed: incorrect references got %v want locations %v", got, want) + } + } + } +} + func (r *runner) Symbol(t *testing.T, data tests.Symbols) { ctx := context.Background() for uri, expectedSymbols := range data { diff --git a/internal/lsp/testdata/foo/foo.go b/internal/lsp/testdata/foo/foo.go index 0e33467f..094623e1 100644 --- a/internal/lsp/testdata/foo/foo.go +++ b/internal/lsp/testdata/foo/foo.go @@ -13,8 +13,8 @@ func Foo() { //@item(Foo, "Foo()", "", "func") } func _() { - var sFoo StructFoo //@complete("t", StructFoo) - if x := sFoo; x.Value == 1 { //@complete("V", Value),typdef("sFoo", StructFoo) + var sFoo StructFoo //@mark(sFoo1, "sFoo"),complete("t", StructFoo) + if x := sFoo; x.Value == 1 { //@mark(sFoo2, "sFoo"),complete("V", Value),typdef("sFoo", StructFoo),refs("sFo", sFoo1, sFoo2) return } } @@ -22,7 +22,7 @@ func _() { func _() { shadowed := 123 { - shadowed := "hi" //@item(shadowed, "shadowed", "string", "var") + shadowed := "hi" //@item(shadowed, "shadowed", "string", "var"),refs("shadowed", shadowed) sha //@complete("a", shadowed) } } diff --git a/internal/lsp/tests/tests.go b/internal/lsp/tests/tests.go index a8dc3495..81f1a107 100644 --- a/internal/lsp/tests/tests.go +++ b/internal/lsp/tests/tests.go @@ -34,6 +34,7 @@ const ( ExpectedDefinitionsCount = 35 ExpectedTypeDefinitionsCount = 2 ExpectedHighlightsCount = 2 + ExpectedReferencesCount = 2 ExpectedSymbolsCount = 1 ExpectedSignaturesCount = 20 ExpectedLinksCount = 2 @@ -56,6 +57,7 @@ type Formats []span.Span type Imports []span.Span type Definitions map[span.Span]Definition type Highlights map[string][]span.Span +type References map[span.Span][]span.Span type Symbols map[span.URI][]source.Symbol type SymbolsChildren map[string][]source.Symbol type Signatures map[span.Span]source.SignatureInformation @@ -72,6 +74,7 @@ type Data struct { Imports Imports Definitions Definitions Highlights Highlights + References References Symbols Symbols symbolsChildren SymbolsChildren Signatures Signatures @@ -90,6 +93,7 @@ type Tests interface { Import(*testing.T, Imports) Definition(*testing.T, Definitions) Highlight(*testing.T, Highlights) + Reference(*testing.T, References) Symbol(*testing.T, Symbols) SignatureHelp(*testing.T, Signatures) Link(*testing.T, Links) @@ -130,6 +134,7 @@ func Load(t testing.TB, exporter packagestest.Exporter, dir string) *Data { CompletionSnippets: make(CompletionSnippets), Definitions: make(Definitions), Highlights: make(Highlights), + References: make(References), Symbols: make(Symbols), symbolsChildren: make(SymbolsChildren), Signatures: make(Signatures), @@ -209,6 +214,7 @@ func Load(t testing.TB, exporter packagestest.Exporter, dir string) *Data { "typdef": data.collectTypeDefinitions, "hover": data.collectHoverDefinitions, "highlight": data.collectHighlights, + "refs": data.collectReferences, "symbol": data.collectSymbols, "signature": data.collectSignatures, "snippet": data.collectCompletionSnippets, @@ -289,6 +295,14 @@ func Run(t *testing.T, tests Tests, data *Data) { tests.Highlight(t, data.Highlights) }) + t.Run("References", func(t *testing.T) { + t.Helper() + if len(data.References) != ExpectedReferencesCount { + t.Errorf("got %v references expected %v", len(data.References), ExpectedReferencesCount) + } + tests.Reference(t, data.References) + }) + t.Run("Symbols", func(t *testing.T) { t.Helper() if len(data.Symbols) != ExpectedSymbolsCount { @@ -456,6 +470,10 @@ func (data *Data) collectHighlights(name string, rng span.Span) { data.Highlights[name] = append(data.Highlights[name], rng) } +func (data *Data) collectReferences(src span.Span, expected []span.Span) { + data.References[src] = expected +} + func (data *Data) collectSymbols(name string, spn span.Span, kind string, parentName string) { sym := source.Symbol{ Name: name,