diff --git a/internal/lsp/lsp_test.go b/internal/lsp/lsp_test.go index 03c862f3..666792ab 100644 --- a/internal/lsp/lsp_test.go +++ b/internal/lsp/lsp_test.go @@ -38,6 +38,7 @@ func testLSP(t *testing.T, exporter packagestest.Exporter) { const expectedDiagnosticsCount = 14 const expectedFormatCount = 3 const expectedDefinitionsCount = 16 + const expectedTypeDefinitionsCount = 2 files := packagestest.MustCopyFileTree(dir) for fragment, operation := range files { @@ -78,6 +79,7 @@ func testLSP(t *testing.T, exporter packagestest.Exporter) { expectedCompletions := make(completions) expectedFormat := make(formats) expectedDefinitions := make(definitions) + expectedTypeDefinitions := make(definitions) // Collect any data that needs to be used by subsequent tests. if err := exported.Expect(map[string]interface{}{ @@ -86,6 +88,7 @@ func testLSP(t *testing.T, exporter packagestest.Exporter) { "complete": expectedCompletions.collect, "format": expectedFormat.collect, "godef": expectedDefinitions.collect, + "typdef": expectedTypeDefinitions.collect, }); err != nil { t.Fatal(err) } @@ -127,7 +130,17 @@ func testLSP(t *testing.T, exporter packagestest.Exporter) { t.Errorf("got %v definitions expected %v", len(expectedDefinitions), expectedDefinitionsCount) } } - expectedDefinitions.test(t, s) + expectedDefinitions.test(t, s, false) + }) + + t.Run("TypeDefinitions", func(t *testing.T) { + t.Helper() + if goVersion111 { // TODO(rstambler): Remove this when we no longer support Go 1.10. + if len(expectedTypeDefinitions) != expectedTypeDefinitionsCount { + t.Errorf("got %v type definitions expected %v", len(expectedTypeDefinitions), expectedTypeDefinitionsCount) + } + } + expectedTypeDefinitions.test(t, s, true) }) } @@ -290,14 +303,21 @@ func (f formats) collect(pos token.Position) { f[pos.Filename] = stdout.String() } -func (d definitions) test(t *testing.T, s *server) { +func (d definitions) test(t *testing.T, s *server, typ bool) { for src, target := range d { - locs, err := s.Definition(context.Background(), &protocol.TextDocumentPositionParams{ + params := &protocol.TextDocumentPositionParams{ TextDocument: protocol.TextDocumentIdentifier{ URI: src.URI, }, Position: src.Range.Start, - }) + } + var locs []protocol.Location + var err error + if typ { + locs, err = s.TypeDefinition(context.Background(), params) + } else { + locs, err = s.Definition(context.Background(), params) + } if err != nil { t.Fatal(err) } diff --git a/internal/lsp/server.go b/internal/lsp/server.go index f5bdf0d2..b8b09525 100644 --- a/internal/lsp/server.go +++ b/internal/lsp/server.go @@ -64,6 +64,7 @@ func (s *server) Initialize(ctx context.Context, params *protocol.InitializePara Change: float64(protocol.Full), // full contents of file sent on each update OpenClose: true, }, + TypeDefinitionProvider: true, }, }, nil } @@ -215,8 +216,18 @@ func (s *server) Definition(ctx context.Context, params *protocol.TextDocumentPo return []protocol.Location{toProtocolLocation(s.view.Config.Fset, r)}, nil } -func (s *server) TypeDefinition(context.Context, *protocol.TextDocumentPositionParams) ([]protocol.Location, error) { - return nil, notImplemented("TypeDefinition") +func (s *server) TypeDefinition(ctx context.Context, params *protocol.TextDocumentPositionParams) ([]protocol.Location, error) { + f := s.view.GetFile(source.URI(params.TextDocument.URI)) + tok, err := f.GetToken() + if err != nil { + return nil, err + } + pos := fromProtocolPosition(tok, params.Position) + r, err := source.TypeDefinition(ctx, f, pos) + if err != nil { + return nil, err + } + return []protocol.Location{toProtocolLocation(s.view.Config.Fset, r)}, nil } func (s *server) Implementation(context.Context, *protocol.TextDocumentPositionParams) ([]protocol.Location, error) { diff --git a/internal/lsp/source/definition.go b/internal/lsp/source/definition.go index 43546381..b4f05f48 100644 --- a/internal/lsp/source/definition.go +++ b/internal/lsp/source/definition.go @@ -48,6 +48,43 @@ func Definition(ctx context.Context, f *File, pos token.Pos) (Range, error) { return objToRange(f.view.Config.Fset, obj), nil } +func TypeDefinition(ctx context.Context, f *File, pos token.Pos) (Range, error) { + fAST, err := f.GetAST() + if err != nil { + return Range{}, err + } + pkg, err := f.GetPackage() + if err != nil { + return Range{}, err + } + i, err := findIdentifier(fAST, pos) + if err != nil { + return Range{}, err + } + if i.ident == nil { + return Range{}, fmt.Errorf("not a valid identifier") + } + typ := pkg.TypesInfo.TypeOf(i.ident) + if typ == nil { + return Range{}, fmt.Errorf("no type for %s", i.ident.Name) + } + obj := typeToObject(typ) + if obj == nil { + return Range{}, fmt.Errorf("no object for type %s", typ.String()) + } + return objToRange(f.view.Config.Fset, obj), nil +} + +func typeToObject(typ types.Type) (obj types.Object) { + switch typ := typ.(type) { + case *types.Named: + obj = typ.Obj() + case *types.Pointer: + obj = typeToObject(typ.Elem()) + } + return obj +} + // ident returns the ident plus any extra information needed type ident struct { ident *ast.Ident diff --git a/internal/lsp/testdata/baz/baz.go.in b/internal/lsp/testdata/baz/baz.go.in index 1af3bc4c..90d952be 100644 --- a/internal/lsp/testdata/baz/baz.go.in +++ b/internal/lsp/testdata/baz/baz.go.in @@ -12,7 +12,7 @@ func Baz() { defer bar.Bar() //@complete("B", Bar) // TODO(rstambler): Test completion here. defer bar.B - var _ f.IntFoo //@complete("n", IntFoo) + var x f.IntFoo //@complete("n", IntFoo),typdef("x", IntFoo) bar.Bar() //@complete("B", Bar) } diff --git a/internal/lsp/testdata/foo/foo.go b/internal/lsp/testdata/foo/foo.go index 27c1b42c..e02099b4 100644 --- a/internal/lsp/testdata/foo/foo.go +++ b/internal/lsp/testdata/foo/foo.go @@ -14,7 +14,7 @@ func Foo() { //@item(Foo, "Foo()", "", "func") func _() { var sFoo StructFoo //@complete("t", StructFoo) - if x := sFoo; x.Value == 1 { //@complete("V", Value) + if x := sFoo; x.Value == 1 { //@complete("V", Value),typdef("sFoo", StructFoo) return } }