go.tools/imports: move goimports from github to go.tools.

From revision d0880223588919729793727c9d65f202a73cda77.

R=golang-dev, bradfitz
CC=golang-dev
https://golang.org/cl/35850048
This commit is contained in:
David Crawshaw 2013-12-17 21:21:03 -05:00
parent 53dfbf8a00
commit c87866116c
8 changed files with 1723 additions and 0 deletions

30
cmd/goimports/README Normal file
View File

@ -0,0 +1,30 @@
This tool updates your Go import lines, adding missing ones and
removing unreferenced ones.
$ go get code.google.com/p/go.tools/cmd/goimports
It's a fork of gofmt, and will also format your code, so it can be
used as a replacement for your gofmt-on-save hook in your editor of
choice.
For emacs, make sure you have the latest (Go 1.2) go-mode.el:
https://go.googlecode.com/hg/misc/emacs/go-mode.el
Then in your .emacs file:
(setq gofmt-command "goimports")
(add-to-list 'load-path "/home/you/goroot/misc/emacs/")
(require 'go-mode-load)
(add-hook 'before-save-hook 'gofmt-before-save)
For vim, set "gofmt_command" to "goimports":
https://code.google.com/p/go/source/detail?r=39c724dd7f252
https://code.google.com/p/go/source/browse#hg%2Fmisc%2Fvim
etc
For GoSublime, follow the steps described here:
http://mdwhatcott.wordpress.com/2013/12/15/installing-and-enabling-goimports-with-gosublime/
For other editors, you probably know what to do.
Happy hacking!

193
cmd/goimports/goimports.go Normal file
View File

@ -0,0 +1,193 @@
// Copyright 2013 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package main
import (
"bytes"
"flag"
"fmt"
"go/scanner"
"io"
"io/ioutil"
"os"
"os/exec"
"path/filepath"
"runtime"
"strings"
"code.google.com/p/go.tools/imports"
)
var (
// main operation modes
list = flag.Bool("l", false, "list files whose formatting differs from goimport's")
write = flag.Bool("w", false, "write result to (source) file instead of stdout")
doDiff = flag.Bool("d", false, "display diffs instead of rewriting files")
options = &imports.Options{}
exitCode = 0
)
func init() {
flag.BoolVar(&options.AllErrors, "e", false, "report all errors (not just the first 10 on different lines)")
flag.BoolVar(&options.Comments, "comments", true, "print comments")
flag.IntVar(&options.TabWidth, "tabwidth", 8, "tab width")
flag.BoolVar(&options.TabIndent, "tabs", true, "indent with tabs")
}
func report(err error) {
scanner.PrintError(os.Stderr, err)
exitCode = 2
}
func usage() {
fmt.Fprintf(os.Stderr, "usage: goimports [flags] [path ...]\n")
flag.PrintDefaults()
os.Exit(2)
}
func isGoFile(f os.FileInfo) bool {
// ignore non-Go files
name := f.Name()
return !f.IsDir() && !strings.HasPrefix(name, ".") && strings.HasSuffix(name, ".go")
}
func processFile(filename string, in io.Reader, out io.Writer, stdin bool) error {
opt := options
if stdin {
nopt := *options
nopt.Fragment = true
opt = &nopt
}
if in == nil {
f, err := os.Open(filename)
if err != nil {
return err
}
defer f.Close()
in = f
}
src, err := ioutil.ReadAll(in)
if err != nil {
return err
}
res, err := imports.Process(filename, src, opt)
if err != nil {
return err
}
if !bytes.Equal(src, res) {
// formatting has changed
if *list {
fmt.Fprintln(out, filename)
}
if *write {
err = ioutil.WriteFile(filename, res, 0)
if err != nil {
return err
}
}
if *doDiff {
data, err := diff(src, res)
if err != nil {
return fmt.Errorf("computing diff: %s", err)
}
fmt.Printf("diff %s gofmt/%s\n", filename, filename)
out.Write(data)
}
}
if !*list && !*write && !*doDiff {
_, err = out.Write(res)
}
return err
}
func visitFile(path string, f os.FileInfo, err error) error {
if err == nil && isGoFile(f) {
err = processFile(path, nil, os.Stdout, false)
}
if err != nil {
report(err)
}
return nil
}
func walkDir(path string) {
filepath.Walk(path, visitFile)
}
func main() {
runtime.GOMAXPROCS(runtime.NumCPU())
// call gofmtMain in a separate function
// so that it can use defer and have them
// run before the exit.
gofmtMain()
os.Exit(exitCode)
}
func gofmtMain() {
flag.Usage = usage
flag.Parse()
if options.TabWidth < 0 {
fmt.Fprintf(os.Stderr, "negative tabwidth %d\n", options.TabWidth)
exitCode = 2
return
}
if flag.NArg() == 0 {
if err := processFile("<standard input>", os.Stdin, os.Stdout, true); err != nil {
report(err)
}
return
}
for i := 0; i < flag.NArg(); i++ {
path := flag.Arg(i)
switch dir, err := os.Stat(path); {
case err != nil:
report(err)
case dir.IsDir():
walkDir(path)
default:
if err := processFile(path, nil, os.Stdout, false); err != nil {
report(err)
}
}
}
}
func diff(b1, b2 []byte) (data []byte, err error) {
f1, err := ioutil.TempFile("", "gofmt")
if err != nil {
return
}
defer os.Remove(f1.Name())
defer f1.Close()
f2, err := ioutil.TempFile("", "gofmt")
if err != nil {
return
}
defer os.Remove(f2.Name())
defer f2.Close()
f1.Write(b1)
f2.Write(b2)
data, err = exec.Command("diff", "-u", f1.Name(), f2.Name()).CombinedOutput()
if len(data) > 0 {
// diff exits with a non-zero status when the files don't match.
// Ignore that failure as long as we get output.
err = nil
}
return
}

328
imports/fix.go Normal file
View File

@ -0,0 +1,328 @@
// Copyright 2013 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package imports
import (
"fmt"
"go/ast"
"go/build"
"go/parser"
"go/token"
"os"
"path"
"path/filepath"
"strings"
"sync"
"code.google.com/p/go.tools/astutil"
)
// importToGroup is a list of functions which map from an import path to
// a group number.
var importToGroup = []func(importPath string) (num int, ok bool){
func(importPath string) (num int, ok bool) {
if strings.HasPrefix(importPath, "appengine") {
return 2, true
}
return
},
func(importPath string) (num int, ok bool) {
if strings.Contains(importPath, ".") {
return 1, true
}
return
},
}
func importGroup(importPath string) int {
for _, fn := range importToGroup {
if n, ok := fn(importPath); ok {
return n
}
}
return 0
}
func fixImports(f *ast.File) (added []string, err error) {
// refs are a set of possible package references currently unsatisified by imports.
// first key: either base package (e.g. "fmt") or renamed package
// second key: referenced package symbol (e.g. "Println")
refs := make(map[string]map[string]bool)
// decls are the current package imports. key is base package or renamed package.
decls := make(map[string]*ast.ImportSpec)
// collect potential uses of packages.
var visitor visitFn
visitor = visitFn(func(node ast.Node) ast.Visitor {
if node == nil {
return visitor
}
switch v := node.(type) {
case *ast.ImportSpec:
if v.Name != nil {
decls[v.Name.Name] = v
} else {
local := importPathToName(strings.Trim(v.Path.Value, `\"`))
decls[local] = v
}
case *ast.SelectorExpr:
xident, ok := v.X.(*ast.Ident)
if !ok {
break
}
if xident.Obj != nil {
// if the parser can resolve it, it's not a package ref
break
}
pkgName := xident.Name
if refs[pkgName] == nil {
refs[pkgName] = make(map[string]bool)
}
if decls[pkgName] == nil {
refs[pkgName][v.Sel.Name] = true
}
}
return visitor
})
ast.Walk(visitor, f)
// Search for imports matching potential package references.
searches := 0
type result struct {
ipath string
err error
}
results := make(chan result)
for pkgName, symbols := range refs {
if len(symbols) == 0 {
continue // skip over packages already imported
}
go func(pkgName string, symbols map[string]bool) {
ipath, err := findImport(pkgName, symbols)
results <- result{ipath, err}
}(pkgName, symbols)
searches++
}
for i := 0; i < searches; i++ {
result := <-results
if result.err != nil {
return nil, result.err
}
if result.ipath != "" {
astutil.AddImport(fset, f, result.ipath)
added = append(added, result.ipath)
}
}
// Nil out any unused ImportSpecs, to be removed in following passes
unusedImport := map[string]bool{}
for pkg, is := range decls {
if refs[pkg] == nil && pkg != "_" && pkg != "." {
unusedImport[strings.Trim(is.Path.Value, `"`)] = true
}
}
for ipath := range unusedImport {
if ipath == "C" {
// Don't remove cgo stuff.
continue
}
astutil.DeleteImport(fset, f, ipath)
}
return added, nil
}
// importPathToName returns the package name for the given import path.
var importPathToName = importPathToNameGoPath
// importPathToNameBasic assumes the package name is the base of import path.
func importPathToNameBasic(importPath string) (packageName string) {
return path.Base(importPath)
}
// importPathToNameGoPath finds out the actual package name, as declared in its .go files.
// If there's a problem, it falls back to using importPathToNameBasic.
func importPathToNameGoPath(importPath string) (packageName string) {
if buildPkg, err := build.Import(importPath, "", 0); err == nil {
return buildPkg.Name
} else {
return importPathToNameBasic(importPath)
}
}
type pkg struct {
importpath string // full pkg import path, e.g. "net/http"
dir string // absolute file path to pkg directory e.g. "/usr/lib/go/src/fmt"
}
var pkgIndexOnce sync.Once
var pkgIndex struct {
sync.Mutex
m map[string][]pkg // shortname => []pkg, e.g "http" => "net/http"
}
func loadPkgIndex() {
pkgIndex.Lock()
pkgIndex.m = make(map[string][]pkg)
pkgIndex.Unlock()
var wg sync.WaitGroup
for _, path := range build.Default.SrcDirs() {
f, err := os.Open(path)
if err != nil {
fmt.Fprint(os.Stderr, err)
continue
}
children, err := f.Readdir(-1)
f.Close()
if err != nil {
fmt.Fprint(os.Stderr, err)
continue
}
for _, child := range children {
if child.IsDir() {
wg.Add(1)
go func(path, name string) {
defer wg.Done()
loadPkg(&wg, path, name)
}(path, child.Name())
}
}
}
wg.Wait()
}
var fset = token.NewFileSet()
func loadPkg(wg *sync.WaitGroup, root, pkgrelpath string) {
importpath := filepath.ToSlash(pkgrelpath)
shortName := importPathToName(importpath)
dir := filepath.Join(root, importpath)
pkgIndex.Lock()
pkgIndex.m[shortName] = append(pkgIndex.m[shortName], pkg{
importpath: importpath,
dir: dir,
})
pkgIndex.Unlock()
pkgDir, err := os.Open(dir)
if err != nil {
return
}
children, err := pkgDir.Readdir(-1)
pkgDir.Close()
if err != nil {
return
}
for _, child := range children {
name := child.Name()
if name == "" {
continue
}
if c := name[0]; c == '.' || ('0' <= c && c <= '9') {
continue
}
if child.IsDir() {
wg.Add(1)
go func(root, name string) {
defer wg.Done()
loadPkg(wg, root, name)
}(root, filepath.Join(importpath, name))
}
}
}
// loadExports returns a list exports for a package.
var loadExports = loadExportsGoPath
func loadExportsGoPath(dir string) map[string]bool {
exports := make(map[string]bool)
buildPkg, err := build.ImportDir(dir, 0)
if err != nil {
if strings.Contains(err.Error(), "no buildable Go source files in") {
return nil
}
fmt.Fprintf(os.Stderr, "could not import %q: %v", dir, err)
return nil
}
for _, file := range buildPkg.GoFiles {
f, err := parser.ParseFile(fset, filepath.Join(dir, file), nil, 0)
if err != nil {
fmt.Fprintf(os.Stderr, "could not parse %q: %v", file, err)
continue
}
for name := range f.Scope.Objects {
if ast.IsExported(name) {
exports[name] = true
}
}
}
return exports
}
// findImport searches for a package with the given symbols.
// If no package is found, findImport returns "".
// Declared as a variable rather than a function so goimports can be easily
// extended by adding a file with an init function.
var findImport = findImportGoPath
func findImportGoPath(pkgName string, symbols map[string]bool) (string, error) {
pkgIndexOnce.Do(loadPkgIndex)
// Collect exports for packages with matching names.
var wg sync.WaitGroup
var pkgsMu sync.Mutex // guards pkgs
// full importpath => exported symbol => True
// e.g. "net/http" => "Client" => True
pkgs := make(map[string]map[string]bool)
pkgIndex.Lock()
for _, pkg := range pkgIndex.m[pkgName] {
wg.Add(1)
go func(importpath, dir string) {
defer wg.Done()
exports := loadExports(dir)
if exports != nil {
pkgsMu.Lock()
pkgs[importpath] = exports
pkgsMu.Unlock()
}
}(pkg.importpath, pkg.dir)
}
pkgIndex.Unlock()
wg.Wait()
// Filter out packages missing required exported symbols.
for symbol := range symbols {
for importpath, exports := range pkgs {
if !exports[symbol] {
delete(pkgs, importpath)
}
}
}
if len(pkgs) == 0 {
return "", nil
}
// If there are multiple candidate packages, the shortest one wins.
// This is a heuristic to prefer the standard library (e.g. "bytes")
// over e.g. "github.com/foo/bar/bytes".
shortest := ""
for importPath := range pkgs {
if shortest == "" || len(importPath) < len(shortest) {
shortest = importPath
}
}
return shortest, nil
}
type visitFn func(node ast.Node) ast.Visitor
func (fn visitFn) Visit(node ast.Node) ast.Visitor {
return fn(node)
}

531
imports/fix_test.go Normal file
View File

@ -0,0 +1,531 @@
// Copyright 2013 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package imports
import (
"bytes"
"flag"
"go/build"
"io/ioutil"
"os"
"path/filepath"
"strings"
"testing"
)
var only = flag.String("only", "", "If non-empty, the fix test to run")
var tests = []struct {
name string
in, out string
}{
// Adding an import to an existing parenthesized import
{
name: "factored_imports_add",
in: `package foo
import (
"fmt"
)
func bar() {
var b bytes.Buffer
fmt.Println(b.String())
}
`,
out: `package foo
import (
"bytes"
"fmt"
)
func bar() {
var b bytes.Buffer
fmt.Println(b.String())
}
`,
},
// Adding an import to an existing parenthesized import,
// verifying it goes into the first section.
{
name: "factored_imports_add_first_sec",
in: `package foo
import (
"fmt"
"appengine"
)
func bar() {
var b bytes.Buffer
_ = appengine.IsDevServer
fmt.Println(b.String())
}
`,
out: `package foo
import (
"bytes"
"fmt"
"appengine"
)
func bar() {
var b bytes.Buffer
_ = appengine.IsDevServer
fmt.Println(b.String())
}
`,
},
// Adding an import to an existing parenthesized import,
// verifying it goes into the first section. (test 2)
{
name: "factored_imports_add_first_sec_2",
in: `package foo
import (
"fmt"
"appengine"
)
func bar() {
_ = math.NaN
_ = fmt.Sprintf
_ = appengine.IsDevServer
}
`,
out: `package foo
import (
"fmt"
"math"
"appengine"
)
func bar() {
_ = math.NaN
_ = fmt.Sprintf
_ = appengine.IsDevServer
}
`,
},
// Adding a new import line, without parens
{
name: "add_import_section",
in: `package foo
func bar() {
var b bytes.Buffer
}
`,
out: `package foo
import "bytes"
func bar() {
var b bytes.Buffer
}
`,
},
// Adding two new imports, which should make a parenthesized import decl.
{
name: "add_import_paren_section",
in: `package foo
func bar() {
_, _ := bytes.Buffer, zip.NewReader
}
`,
out: `package foo
import (
"archive/zip"
"bytes"
)
func bar() {
_, _ := bytes.Buffer, zip.NewReader
}
`,
},
// Make sure we don't add things twice
{
name: "no_double_add",
in: `package foo
func bar() {
_, _ := bytes.Buffer, bytes.NewReader
}
`,
out: `package foo
import "bytes"
func bar() {
_, _ := bytes.Buffer, bytes.NewReader
}
`,
},
// Remove unused imports, 1 of a factored block
{
name: "remove_unused_1_of_2",
in: `package foo
import (
"bytes"
"fmt"
)
func bar() {
_, _ := bytes.Buffer, bytes.NewReader
}
`,
out: `package foo
import "bytes"
func bar() {
_, _ := bytes.Buffer, bytes.NewReader
}
`,
},
// Remove unused imports, 2 of 2
{
name: "remove_unused_2_of_2",
in: `package foo
import (
"bytes"
"fmt"
)
func bar() {
}
`,
out: `package foo
func bar() {
}
`,
},
// Remove unused imports, 1 of 1
{
name: "remove_unused_1_of_1",
in: `package foo
import "fmt"
func bar() {
}
`,
out: `package foo
func bar() {
}
`,
},
// Don't remove empty imports.
{
name: "dont_remove_empty_imports",
in: `package foo
import (
_ "image/png"
_ "image/jpeg"
)
`,
out: `package foo
import (
_ "image/jpeg"
_ "image/png"
)
`,
},
// Don't remove dot imports.
{
name: "dont_remove_dot_imports",
in: `package foo
import (
. "foo"
. "bar"
)
`,
out: `package foo
import (
. "bar"
. "foo"
)
`,
},
// Skip refs the parser can resolve.
{
name: "skip_resolved_refs",
in: `package foo
func f() {
type t struct{ Println func(string) }
fmt := t{Println: func(string) {}}
fmt.Println("foo")
}
`,
out: `package foo
func f() {
type t struct{ Println func(string) }
fmt := t{Println: func(string) {}}
fmt.Println("foo")
}
`,
},
// Do not add a package we already have a resolution for.
{
name: "skip_template",
in: `package foo
import "html/template"
func f() { t = template.New("sometemplate") }
`,
out: `package foo
import "html/template"
func f() { t = template.New("sometemplate") }
`,
},
// Don't touch cgo
{
name: "cgo",
in: `package foo
/*
#include <foo.h>
*/
import "C"
`,
out: `package foo
/*
#include <foo.h>
*/
import "C"
`,
},
// Put some things in their own section
{
name: "make_sections",
in: `package foo
import (
"os"
)
func foo () {
_, _ = os.Args, fmt.Println
_, _ = appengine.FooSomething, user.Current
}
`,
out: `package foo
import (
"fmt"
"os"
"appengine"
"appengine/user"
)
func foo() {
_, _ = os.Args, fmt.Println
_, _ = appengine.FooSomething, user.Current
}
`,
},
// Delete existing empty import block
{
name: "delete_empty_import_block",
in: `package foo
import ()
`,
out: `package foo
`,
},
// Use existing empty import block
{
name: "use_empty_import_block",
in: `package foo
import ()
func f() {
_ = fmt.Println
}
`,
out: `package foo
import "fmt"
func f() {
_ = fmt.Println
}
`,
},
// Blank line before adding new section.
{
name: "blank_line_before_new_group",
in: `package foo
import (
"fmt"
"net"
)
func f() {
_ = net.Dial
_ = fmt.Printf
_ = snappy.Foo
}
`,
out: `package foo
import (
"fmt"
"net"
"code.google.com/p/snappy-go/snappy"
)
func f() {
_ = net.Dial
_ = fmt.Printf
_ = snappy.Foo
}
`,
},
// Blank line between standard library and third-party stuff.
{
name: "blank_line_separating_std_and_third_party",
in: `package foo
import (
"code.google.com/p/snappy-go/snappy"
"fmt"
"net"
)
func f() {
_ = net.Dial
_ = fmt.Printf
_ = snappy.Foo
}
`,
out: `package foo
import (
"fmt"
"net"
"code.google.com/p/snappy-go/snappy"
)
func f() {
_ = net.Dial
_ = fmt.Printf
_ = snappy.Foo
}
`,
},
}
func TestFixImports(t *testing.T) {
simplePkgs := map[string]string{
"fmt": "fmt",
"os": "os",
"math": "math",
"appengine": "appengine",
"user": "appengine/user",
"zip": "archive/zip",
"bytes": "bytes",
"snappy": "code.google.com/p/snappy-go/snappy",
}
findImport = func(pkgName string, symbols map[string]bool) (string, error) {
return simplePkgs[pkgName], nil
}
for _, tt := range tests {
if *only != "" && tt.name != *only {
continue
}
var buf bytes.Buffer
err := processFile("foo.go", strings.NewReader(tt.in), &buf, false)
if err != nil {
t.Errorf("error on %q: %v", tt.name, err)
continue
}
if got := buf.String(); got != tt.out {
t.Errorf("results diff on %q\nGOT:\n%s\nWANT:\n%s\n", tt.name, got, tt.out)
}
}
}
func TestFindImportGoPath(t *testing.T) {
goroot, err := ioutil.TempDir("", "goimports-")
if err != nil {
t.Fatal(err)
}
defer os.RemoveAll(goroot)
// Test against imaginary bits/bytes package in std lib
bytesDir := filepath.Join(goroot, "src", "pkg", "bits", "bytes")
if err := os.MkdirAll(bytesDir, 0755); err != nil {
t.Fatal(err)
}
bytesSrcPath := filepath.Join(bytesDir, "bytes.go")
bytesPkgPath := "bits/bytes"
bytesSrc := []byte(`package bytes
type Buffer2 struct {}
`)
if err := ioutil.WriteFile(bytesSrcPath, bytesSrc, 0775); err != nil {
t.Fatal(err)
}
oldGOROOT := build.Default.GOROOT
oldGOPATH := build.Default.GOPATH
build.Default.GOROOT = goroot
build.Default.GOPATH = ""
defer func() {
build.Default.GOROOT = oldGOROOT
build.Default.GOPATH = oldGOPATH
}()
got, err := findImportGoPath("bytes", map[string]bool{"Buffer2": true})
if err != nil {
t.Fatal(err)
}
if got != bytesPkgPath {
t.Errorf(`findImportGoPath("bytes", Buffer2 ...)=%q, want "%s"`, got, bytesPkgPath)
}
got, err = findImportGoPath("bytes", map[string]bool{"Missing": true})
if err != nil {
t.Fatal(err)
}
if got != "" {
t.Errorf(`findImportGoPath("bytes", Missing ...)=%q, want ""`, got)
}
}

240
imports/imports.go Normal file
View File

@ -0,0 +1,240 @@
// Copyright 2013 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package imports implements a Go pretty-printer (like package "go/format")
// that also adds or removes import statements as necessary.
package imports
import (
"bufio"
"bytes"
"fmt"
"go/ast"
"go/parser"
"go/printer"
"go/token"
"io"
"regexp"
"strconv"
"strings"
"code.google.com/p/go.tools/astutil"
)
// Options specifies options for processing files.
type Options struct {
Fragment bool // Accept fragement of a source file (no package statement)
AllErrors bool // Report all errors (not just the first 10 on different lines)
Comments bool // Print comments (true if nil *Options provided)
TabIndent bool // Use tabs for indent (true if nil *Options provided)
TabWidth int // Tab width (8 if nil *Options provided)
}
// Process formats and adjusts imports for the provided file.
// If opt is nil the defaults are used.
func Process(filename string, src []byte, opt *Options) ([]byte, error) {
if opt == nil {
opt = &Options{Comments: true, TabIndent: true, TabWidth: 8}
}
fileSet := token.NewFileSet()
file, adjust, err := parse(fileSet, filename, src, opt)
if err != nil {
return nil, err
}
_, err = fixImports(file)
if err != nil {
return nil, err
}
sortImports(fileSet, file)
imps := astutil.Imports(fileSet, file)
var spacesBefore []string // import paths we need spaces before
if len(imps) == 1 {
// We have just one block of imports. See if any are in different groups numbers.
lastGroup := -1
for _, importSpec := range imps[0] {
importPath, _ := strconv.Unquote(importSpec.Path.Value)
groupNum := importGroup(importPath)
if groupNum != lastGroup && lastGroup != -1 {
spacesBefore = append(spacesBefore, importPath)
}
lastGroup = groupNum
}
}
printerMode := printer.UseSpaces
if opt.TabIndent {
printerMode |= printer.TabIndent
}
printConfig := &printer.Config{Mode: printerMode, Tabwidth: opt.TabWidth}
var buf bytes.Buffer
err = printConfig.Fprint(&buf, fileSet, file)
if err != nil {
return nil, err
}
out := buf.Bytes()
if adjust != nil {
out = adjust(src, out)
}
if len(spacesBefore) > 0 {
out = addImportSpaces(bytes.NewReader(out), spacesBefore)
}
return out, nil
}
// parse parses src, which was read from filename,
// as a Go source file or statement list.
func parse(fset *token.FileSet, filename string, src []byte, opt *Options) (*ast.File, func(orig, src []byte) []byte, error) {
parserMode := parser.Mode(0)
if opt.Comments {
parserMode |= parser.ParseComments
}
if opt.AllErrors {
parserMode |= parser.AllErrors
}
// Try as whole source file.
file, err := parser.ParseFile(fset, filename, src, parserMode)
if err == nil {
return file, nil, nil
}
// If the error is that the source file didn't begin with a
// package line and we accept fragmented input, fall through to
// try as a source fragment. Stop and return on any other error.
if !opt.Fragment || !strings.Contains(err.Error(), "expected 'package'") {
return nil, nil, err
}
// If this is a declaration list, make it a source file
// by inserting a package clause.
// Insert using a ;, not a newline, so that the line numbers
// in psrc match the ones in src.
psrc := append([]byte("package p;"), src...)
file, err = parser.ParseFile(fset, filename, psrc, parserMode)
if err == nil {
adjust := func(orig, src []byte) []byte {
// Remove the package clause.
// Gofmt has turned the ; into a \n.
src = src[len("package p\n"):]
return matchSpace(orig, src)
}
return file, adjust, nil
}
// If the error is that the source file didn't begin with a
// declaration, fall through to try as a statement list.
// Stop and return on any other error.
if !strings.Contains(err.Error(), "expected declaration") {
return nil, nil, err
}
// If this is a statement list, make it a source file
// by inserting a package clause and turning the list
// into a function body. This handles expressions too.
// Insert using a ;, not a newline, so that the line numbers
// in fsrc match the ones in src.
fsrc := append(append([]byte("package p; func _() {"), src...), '}')
file, err = parser.ParseFile(fset, filename, fsrc, parserMode)
if err == nil {
adjust := func(orig, src []byte) []byte {
// Remove the wrapping.
// Gofmt has turned the ; into a \n\n.
src = src[len("package p\n\nfunc _() {"):]
src = src[:len(src)-len("}\n")]
// Gofmt has also indented the function body one level.
// Remove that indent.
src = bytes.Replace(src, []byte("\n\t"), []byte("\n"), -1)
return matchSpace(orig, src)
}
return file, adjust, nil
}
// Failed, and out of options.
return nil, nil, err
}
func cutSpace(b []byte) (before, middle, after []byte) {
i := 0
for i < len(b) && (b[i] == ' ' || b[i] == '\t' || b[i] == '\n') {
i++
}
j := len(b)
for j > 0 && (b[j-1] == ' ' || b[j-1] == '\t' || b[j-1] == '\n') {
j--
}
if i <= j {
return b[:i], b[i:j], b[j:]
}
return nil, nil, b[j:]
}
// matchSpace reformats src to use the same space context as orig.
// 1) If orig begins with blank lines, matchSpace inserts them at the beginning of src.
// 2) matchSpace copies the indentation of the first non-blank line in orig
// to every non-blank line in src.
// 3) matchSpace copies the trailing space from orig and uses it in place
// of src's trailing space.
func matchSpace(orig []byte, src []byte) []byte {
before, _, after := cutSpace(orig)
i := bytes.LastIndex(before, []byte{'\n'})
before, indent := before[:i+1], before[i+1:]
_, src, _ = cutSpace(src)
var b bytes.Buffer
b.Write(before)
for len(src) > 0 {
line := src
if i := bytes.IndexByte(line, '\n'); i >= 0 {
line, src = line[:i+1], line[i+1:]
} else {
src = nil
}
if len(line) > 0 && line[0] != '\n' { // not blank
b.Write(indent)
}
b.Write(line)
}
b.Write(after)
return b.Bytes()
}
var impLine = regexp.MustCompile(`^\s+(?:\w+\s+)?"(.+)"`)
func addImportSpaces(r io.Reader, breaks []string) []byte {
var out bytes.Buffer
sc := bufio.NewScanner(r)
inImports := false
done := false
for sc.Scan() {
s := sc.Text()
if !inImports && !done && strings.HasPrefix(s, "import") {
inImports = true
}
if inImports && (strings.HasPrefix(s, "var") ||
strings.HasPrefix(s, "func") ||
strings.HasPrefix(s, "const") ||
strings.HasPrefix(s, "type")) {
done = true
inImports = false
}
if inImports && len(breaks) > 0 {
if m := impLine.FindStringSubmatch(s); m != nil {
if m[1] == string(breaks[0]) {
out.WriteByte('\n')
breaks = breaks[1:]
}
}
}
fmt.Fprintln(&out, s)
}
return out.Bytes()
}

173
imports/mkindex.go Normal file
View File

@ -0,0 +1,173 @@
// +build ignore
// Copyright 2013 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Command mkindex creates the file "pkgindex.go" containing an index of the Go
// standard library. The file is intended to be built as part of the imports
// package, so that the package may be used in environments where a GOROOT is
// not available (such as App Engine).
package main
import (
"bytes"
"fmt"
"go/ast"
"go/build"
"go/format"
"go/parser"
"go/token"
"io/ioutil"
"log"
"os"
"path"
"path/filepath"
"strings"
)
var (
pkgIndex = make(map[string][]pkg)
exports = make(map[string]map[string]bool)
)
func main() {
// Don't use GOPATH.
ctx := build.Default
ctx.GOPATH = ""
// Populate pkgIndex global from GOROOT.
for _, path := range ctx.SrcDirs() {
f, err := os.Open(path)
if err != nil {
log.Print(err)
continue
}
children, err := f.Readdir(-1)
f.Close()
if err != nil {
log.Print(err)
continue
}
for _, child := range children {
if child.IsDir() {
loadPkg(path, child.Name())
}
}
}
// Populate exports global.
for _, ps := range pkgIndex {
for _, p := range ps {
e := loadExports(p.dir)
if e != nil {
exports[p.dir] = e
}
}
}
// Construct source file.
var buf bytes.Buffer
fmt.Fprint(&buf, pkgIndexHead)
fmt.Fprintf(&buf, "var pkgIndexMaster = %#v\n", pkgIndex)
fmt.Fprintf(&buf, "var exportsMaster = %#v\n", exports)
src := buf.Bytes()
// Replace main.pkg type name with pkg.
src = bytes.Replace(src, []byte("main.pkg"), []byte("pkg"), -1)
// Replace actual GOROOT with "/go".
src = bytes.Replace(src, []byte(ctx.GOROOT), []byte("/go"), -1)
// Add some line wrapping.
src = bytes.Replace(src, []byte("}, "), []byte("},\n"), -1)
src = bytes.Replace(src, []byte("true, "), []byte("true,\n"), -1)
var err error
src, err = format.Source(src)
if err != nil {
log.Fatal(err)
}
// Write out source file.
err = ioutil.WriteFile("pkgindex.go", src, 0644)
if err != nil {
log.Fatal(err)
}
}
const pkgIndexHead = `package imports
func init() {
pkgIndexOnce.Do(func() {
pkgIndex.m = pkgIndexMaster
})
loadExports = func(dir string) map[string]bool {
return exportsMaster[dir]
}
}
`
type pkg struct {
importpath string // full pkg import path, e.g. "net/http"
dir string // absolute file path to pkg directory e.g. "/usr/lib/go/src/fmt"
}
var fset = token.NewFileSet()
func loadPkg(root, importpath string) {
shortName := path.Base(importpath)
if shortName == "testdata" {
return
}
dir := filepath.Join(root, importpath)
pkgIndex[shortName] = append(pkgIndex[shortName], pkg{
importpath: importpath,
dir: dir,
})
pkgDir, err := os.Open(dir)
if err != nil {
return
}
children, err := pkgDir.Readdir(-1)
pkgDir.Close()
if err != nil {
return
}
for _, child := range children {
name := child.Name()
if name == "" {
continue
}
if c := name[0]; c == '.' || ('0' <= c && c <= '9') {
continue
}
if child.IsDir() {
loadPkg(root, filepath.Join(importpath, name))
}
}
}
func loadExports(dir string) map[string]bool {
exports := make(map[string]bool)
buildPkg, err := build.ImportDir(dir, 0)
if err != nil {
if strings.Contains(err.Error(), "no buildable Go source files in") {
return nil
}
log.Printf("could not import %q: %v", dir, err)
return nil
}
for _, file := range buildPkg.GoFiles {
f, err := parser.ParseFile(fset, filepath.Join(dir, file), nil, 0)
if err != nil {
log.Printf("could not parse %q: %v", file, err)
continue
}
for name := range f.Scope.Objects {
if ast.IsExported(name) {
exports[name] = true
}
}
}
return exports
}

214
imports/sortimports.go Normal file
View File

@ -0,0 +1,214 @@
// +build go1.2
// Copyright 2013 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Hacked up copy of go/ast/import.go
package imports
import (
"go/ast"
"go/token"
"sort"
"strconv"
)
// sortImports sorts runs of consecutive import lines in import blocks in f.
// It also removes duplicate imports when it is possible to do so without data loss.
func sortImports(fset *token.FileSet, f *ast.File) {
for i, d := range f.Decls {
d, ok := d.(*ast.GenDecl)
if !ok || d.Tok != token.IMPORT {
// Not an import declaration, so we're done.
// Imports are always first.
break
}
if len(d.Specs) == 0 {
// Empty import block, remove it.
f.Decls = append(f.Decls[:i], f.Decls[i+1:]...)
}
if !d.Lparen.IsValid() {
// Not a block: sorted by default.
continue
}
// Identify and sort runs of specs on successive lines.
i := 0
specs := d.Specs[:0]
for j, s := range d.Specs {
if j > i && fset.Position(s.Pos()).Line > 1+fset.Position(d.Specs[j-1].End()).Line {
// j begins a new run. End this one.
specs = append(specs, sortSpecs(fset, f, d.Specs[i:j])...)
i = j
}
}
specs = append(specs, sortSpecs(fset, f, d.Specs[i:])...)
d.Specs = specs
// Deduping can leave a blank line before the rparen; clean that up.
if len(d.Specs) > 0 {
lastSpec := d.Specs[len(d.Specs)-1]
lastLine := fset.Position(lastSpec.Pos()).Line
if rParenLine := fset.Position(d.Rparen).Line; rParenLine > lastLine+1 {
fset.File(d.Rparen).MergeLine(rParenLine - 1)
}
}
}
}
func importPath(s ast.Spec) string {
t, err := strconv.Unquote(s.(*ast.ImportSpec).Path.Value)
if err == nil {
return t
}
return ""
}
func importName(s ast.Spec) string {
n := s.(*ast.ImportSpec).Name
if n == nil {
return ""
}
return n.Name
}
func importComment(s ast.Spec) string {
c := s.(*ast.ImportSpec).Comment
if c == nil {
return ""
}
return c.Text()
}
// collapse indicates whether prev may be removed, leaving only next.
func collapse(prev, next ast.Spec) bool {
if importPath(next) != importPath(prev) || importName(next) != importName(prev) {
return false
}
return prev.(*ast.ImportSpec).Comment == nil
}
type posSpan struct {
Start token.Pos
End token.Pos
}
func sortSpecs(fset *token.FileSet, f *ast.File, specs []ast.Spec) []ast.Spec {
// Can't short-circuit here even if specs are already sorted,
// since they might yet need deduplication.
// A lone import, however, may be safely ignored.
if len(specs) <= 1 {
return specs
}
// Record positions for specs.
pos := make([]posSpan, len(specs))
for i, s := range specs {
pos[i] = posSpan{s.Pos(), s.End()}
}
// Identify comments in this range.
// Any comment from pos[0].Start to the final line counts.
lastLine := fset.Position(pos[len(pos)-1].End).Line
cstart := len(f.Comments)
cend := len(f.Comments)
for i, g := range f.Comments {
if g.Pos() < pos[0].Start {
continue
}
if i < cstart {
cstart = i
}
if fset.Position(g.End()).Line > lastLine {
cend = i
break
}
}
comments := f.Comments[cstart:cend]
// Assign each comment to the import spec preceding it.
importComment := map[*ast.ImportSpec][]*ast.CommentGroup{}
specIndex := 0
for _, g := range comments {
for specIndex+1 < len(specs) && pos[specIndex+1].Start <= g.Pos() {
specIndex++
}
s := specs[specIndex].(*ast.ImportSpec)
importComment[s] = append(importComment[s], g)
}
// Sort the import specs by import path.
// Remove duplicates, when possible without data loss.
// Reassign the import paths to have the same position sequence.
// Reassign each comment to abut the end of its spec.
// Sort the comments by new position.
sort.Sort(byImportSpec(specs))
// Dedup. Thanks to our sorting, we can just consider
// adjacent pairs of imports.
deduped := specs[:0]
for i, s := range specs {
if i == len(specs)-1 || !collapse(s, specs[i+1]) {
deduped = append(deduped, s)
} else {
p := s.Pos()
fset.File(p).MergeLine(fset.Position(p).Line)
}
}
specs = deduped
// Fix up comment positions
for i, s := range specs {
s := s.(*ast.ImportSpec)
if s.Name != nil {
s.Name.NamePos = pos[i].Start
}
s.Path.ValuePos = pos[i].Start
s.EndPos = pos[i].End
for _, g := range importComment[s] {
for _, c := range g.List {
c.Slash = pos[i].End
}
}
}
sort.Sort(byCommentPos(comments))
return specs
}
type byImportSpec []ast.Spec // slice of *ast.ImportSpec
func (x byImportSpec) Len() int { return len(x) }
func (x byImportSpec) Swap(i, j int) { x[i], x[j] = x[j], x[i] }
func (x byImportSpec) Less(i, j int) bool {
ipath := importPath(x[i])
jpath := importPath(x[j])
igroup := importGroup(ipath)
jgroup := importGroup(jpath)
if igroup != jgroup {
return igroup < jgroup
}
if ipath != jpath {
return ipath < jpath
}
iname := importName(x[i])
jname := importName(x[j])
if iname != jname {
return iname < jname
}
return importComment(x[i]) < importComment(x[j])
}
type byCommentPos []*ast.CommentGroup
func (x byCommentPos) Len() int { return len(x) }
func (x byCommentPos) Swap(i, j int) { x[i], x[j] = x[j], x[i] }
func (x byCommentPos) Less(i, j int) bool { return x[i].Pos() < x[j].Pos() }

View File

@ -0,0 +1,14 @@
// +build !go1.2
// Copyright 2013 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package imports
import "go/ast"
// Go 1.1 users don't get fancy package grouping.
// But this is still gofmt-compliant:
var sortImports = ast.SortImports