diff --git a/internal/lsp/cache/file.go b/internal/lsp/cache/file.go index 1a12a383..d75c6916 100644 --- a/internal/lsp/cache/file.go +++ b/internal/lsp/cache/file.go @@ -24,11 +24,12 @@ type File struct { pkg *packages.Package } -// Read returns the contents of the file, reading it from file system if needed. -func (f *File) Read() ([]byte, error) { +// GetContent returns the contents of the file, reading it from file system if needed. +func (f *File) GetContent() []byte { f.view.mu.Lock() defer f.view.mu.Unlock() - return f.read() + f.read() + return f.content } func (f *File) GetFileSet() *token.FileSet { @@ -69,19 +70,18 @@ func (f *File) GetPackage() *packages.Package { } // read is the internal part of Read that presumes the lock is already held -func (f *File) read() ([]byte, error) { +func (f *File) read() { if f.content != nil { - return f.content, nil + return } // we don't know the content yet, so read it filename, err := f.URI.Filename() if err != nil { - return nil, err + return } content, err := ioutil.ReadFile(filename) if err != nil { - return nil, err + return } f.content = content - return f.content, nil } diff --git a/internal/lsp/format.go b/internal/lsp/format.go index 650f72e1..be88dacf 100644 --- a/internal/lsp/format.go +++ b/internal/lsp/format.go @@ -2,7 +2,6 @@ package lsp import ( "context" - "go/token" "golang.org/x/tools/internal/lsp/protocol" "golang.org/x/tools/internal/lsp/source" @@ -26,21 +25,19 @@ func formatRange(ctx context.Context, v source.View, uri protocol.DocumentURI, r } else { r = fromProtocolRange(tok, *rng) } - content, err := f.Read() - if err != nil { - return nil, err - } edits, err := source.Format(ctx, f, r) if err != nil { return nil, err } - return toProtocolEdits(tok, content, edits), nil + return toProtocolEdits(f, edits), nil } -func toProtocolEdits(tok *token.File, content []byte, edits []source.TextEdit) []protocol.TextEdit { +func toProtocolEdits(f source.File, edits []source.TextEdit) []protocol.TextEdit { if edits == nil { return nil } + tok := f.GetToken() + content := f.GetContent() // When a file ends with an empty line, the newline character is counted // as part of the previous line. This causes the formatter to insert // another unnecessary newline on each formatting. We handle this case by diff --git a/internal/lsp/imports.go b/internal/lsp/imports.go index 3b57db3b..fad98c82 100644 --- a/internal/lsp/imports.go +++ b/internal/lsp/imports.go @@ -25,13 +25,9 @@ func organizeImports(ctx context.Context, v source.View, uri protocol.DocumentUR Start: tok.Pos(0), End: tok.Pos(tok.Size()), } - content, err := f.Read() + edits, err := source.Imports(ctx, f, r) if err != nil { return nil, err } - edits, err := source.Imports(ctx, tok, content, r) - if err != nil { - return nil, err - } - return toProtocolEdits(tok, content, edits), nil + return toProtocolEdits(f, edits), nil } diff --git a/internal/lsp/lsp_test.go b/internal/lsp/lsp_test.go index 5dc80a96..75d22925 100644 --- a/internal/lsp/lsp_test.go +++ b/internal/lsp/lsp_test.go @@ -374,10 +374,6 @@ func (f formats) test(t *testing.T, s *server) { if err != nil { t.Error(err) } - original, err := f.Read() - if err != nil { - t.Error(err) - } var ops []*diff.Op for _, edit := range edits { start := int(edit.Range.Start.Line) @@ -400,7 +396,7 @@ func (f formats) test(t *testing.T, s *server) { }) } } - split := strings.SplitAfter(string(original), "\n") + split := strings.SplitAfter(string(f.GetContent()), "\n") got := strings.Join(diff.ApplyEdits(split, ops), "") if gofmted != got { t.Errorf("format failed for %s: expected '%v', got '%v'", filename, gofmted, got) diff --git a/internal/lsp/source/diagnostics.go b/internal/lsp/source/diagnostics.go index 9fa26b7a..99a8fd9f 100644 --- a/internal/lsp/source/diagnostics.go +++ b/internal/lsp/source/diagnostics.go @@ -80,11 +80,7 @@ func Diagnostics(ctx context.Context, v View, uri URI) (map[string][]Diagnostic, continue } diagTok := diagFile.GetToken() - content, err := diagFile.Read() - if err != nil { - continue - } - end, err := identifierEnd(content, pos.Line, pos.Column) + end, err := identifierEnd(diagFile.GetContent(), pos.Line, pos.Column) // Don't set a range if it's anything other than a type error. if err != nil || diag.Kind != packages.TypeError { end = 0 diff --git a/internal/lsp/source/format.go b/internal/lsp/source/format.go index ac5a304c..3c4ae69c 100644 --- a/internal/lsp/source/format.go +++ b/internal/lsp/source/format.go @@ -52,25 +52,21 @@ func Format(ctx context.Context, f File, rng Range) ([]TextEdit, error) { if err := format.Node(buf, fset, node); err != nil { return nil, err } - content, err := f.Read() - if err != nil { - return nil, err - } - tok := f.GetToken() - return computeTextEdits(rng, tok, string(content), buf.String()), nil + return computeTextEdits(rng, f, buf.String()), nil } // Imports formats a file using the goimports tool. -func Imports(ctx context.Context, tok *token.File, content []byte, rng Range) ([]TextEdit, error) { - formatted, err := imports.Process(tok.Name(), content, nil) +func Imports(ctx context.Context, f File, rng Range) ([]TextEdit, error) { + formatted, err := imports.Process(f.GetToken().Name(), f.GetContent(), nil) if err != nil { return nil, err } - return computeTextEdits(rng, tok, string(content), string(formatted)), nil + return computeTextEdits(rng, f, string(formatted)), nil } -func computeTextEdits(rng Range, tok *token.File, unformatted, formatted string) (edits []TextEdit) { - u := strings.SplitAfter(unformatted, "\n") +func computeTextEdits(rng Range, file File, formatted string) (edits []TextEdit) { + u := strings.SplitAfter(string(file.GetContent()), "\n") + tok := file.GetToken() f := strings.SplitAfter(formatted, "\n") for _, op := range diff.Operations(u, f) { start := lineStart(tok, op.I1+1) diff --git a/internal/lsp/source/view.go b/internal/lsp/source/view.go index eaac87dd..5a427170 100644 --- a/internal/lsp/source/view.go +++ b/internal/lsp/source/view.go @@ -31,7 +31,7 @@ type File interface { GetFileSet() *token.FileSet GetPackage() *packages.Package GetToken() *token.File - Read() ([]byte, error) + GetContent() []byte } // Range represents a start and end position.