From c87866116c5004382833e51aa08d938a6f6ff120 Mon Sep 17 00:00:00 2001 From: David Crawshaw Date: Tue, 17 Dec 2013 21:21:03 -0500 Subject: [PATCH] go.tools/imports: move goimports from github to go.tools. From revision d0880223588919729793727c9d65f202a73cda77. R=golang-dev, bradfitz CC=golang-dev https://golang.org/cl/35850048 --- cmd/goimports/README | 30 ++ cmd/goimports/goimports.go | 193 ++++++++++++ imports/fix.go | 328 +++++++++++++++++++++ imports/fix_test.go | 531 ++++++++++++++++++++++++++++++++++ imports/imports.go | 240 +++++++++++++++ imports/mkindex.go | 173 +++++++++++ imports/sortimports.go | 214 ++++++++++++++ imports/sortimports_compat.go | 14 + 8 files changed, 1723 insertions(+) create mode 100644 cmd/goimports/README create mode 100644 cmd/goimports/goimports.go create mode 100644 imports/fix.go create mode 100644 imports/fix_test.go create mode 100644 imports/imports.go create mode 100644 imports/mkindex.go create mode 100644 imports/sortimports.go create mode 100644 imports/sortimports_compat.go diff --git a/cmd/goimports/README b/cmd/goimports/README new file mode 100644 index 00000000..6a4a8106 --- /dev/null +++ b/cmd/goimports/README @@ -0,0 +1,30 @@ +This tool updates your Go import lines, adding missing ones and +removing unreferenced ones. + + $ go get code.google.com/p/go.tools/cmd/goimports + +It's a fork of gofmt, and will also format your code, so it can be +used as a replacement for your gofmt-on-save hook in your editor of +choice. + +For emacs, make sure you have the latest (Go 1.2) go-mode.el: + https://go.googlecode.com/hg/misc/emacs/go-mode.el + +Then in your .emacs file: + (setq gofmt-command "goimports") + (add-to-list 'load-path "/home/you/goroot/misc/emacs/") + (require 'go-mode-load) + (add-hook 'before-save-hook 'gofmt-before-save) + +For vim, set "gofmt_command" to "goimports": + + https://code.google.com/p/go/source/detail?r=39c724dd7f252 + https://code.google.com/p/go/source/browse#hg%2Fmisc%2Fvim + etc + +For GoSublime, follow the steps described here: + http://mdwhatcott.wordpress.com/2013/12/15/installing-and-enabling-goimports-with-gosublime/ + +For other editors, you probably know what to do. + +Happy hacking! diff --git a/cmd/goimports/goimports.go b/cmd/goimports/goimports.go new file mode 100644 index 00000000..1f0b4419 --- /dev/null +++ b/cmd/goimports/goimports.go @@ -0,0 +1,193 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package main + +import ( + "bytes" + "flag" + "fmt" + "go/scanner" + "io" + "io/ioutil" + "os" + "os/exec" + "path/filepath" + "runtime" + "strings" + + "code.google.com/p/go.tools/imports" +) + +var ( + // main operation modes + list = flag.Bool("l", false, "list files whose formatting differs from goimport's") + write = flag.Bool("w", false, "write result to (source) file instead of stdout") + doDiff = flag.Bool("d", false, "display diffs instead of rewriting files") + + options = &imports.Options{} + exitCode = 0 +) + +func init() { + flag.BoolVar(&options.AllErrors, "e", false, "report all errors (not just the first 10 on different lines)") + flag.BoolVar(&options.Comments, "comments", true, "print comments") + flag.IntVar(&options.TabWidth, "tabwidth", 8, "tab width") + flag.BoolVar(&options.TabIndent, "tabs", true, "indent with tabs") +} + +func report(err error) { + scanner.PrintError(os.Stderr, err) + exitCode = 2 +} + +func usage() { + fmt.Fprintf(os.Stderr, "usage: goimports [flags] [path ...]\n") + flag.PrintDefaults() + os.Exit(2) +} + +func isGoFile(f os.FileInfo) bool { + // ignore non-Go files + name := f.Name() + return !f.IsDir() && !strings.HasPrefix(name, ".") && strings.HasSuffix(name, ".go") +} + +func processFile(filename string, in io.Reader, out io.Writer, stdin bool) error { + opt := options + if stdin { + nopt := *options + nopt.Fragment = true + opt = &nopt + } + + if in == nil { + f, err := os.Open(filename) + if err != nil { + return err + } + defer f.Close() + in = f + } + + src, err := ioutil.ReadAll(in) + if err != nil { + return err + } + + res, err := imports.Process(filename, src, opt) + if err != nil { + return err + } + + if !bytes.Equal(src, res) { + // formatting has changed + if *list { + fmt.Fprintln(out, filename) + } + if *write { + err = ioutil.WriteFile(filename, res, 0) + if err != nil { + return err + } + } + if *doDiff { + data, err := diff(src, res) + if err != nil { + return fmt.Errorf("computing diff: %s", err) + } + fmt.Printf("diff %s gofmt/%s\n", filename, filename) + out.Write(data) + } + } + + if !*list && !*write && !*doDiff { + _, err = out.Write(res) + } + + return err +} + +func visitFile(path string, f os.FileInfo, err error) error { + if err == nil && isGoFile(f) { + err = processFile(path, nil, os.Stdout, false) + } + if err != nil { + report(err) + } + return nil +} + +func walkDir(path string) { + filepath.Walk(path, visitFile) +} + +func main() { + runtime.GOMAXPROCS(runtime.NumCPU()) + + // call gofmtMain in a separate function + // so that it can use defer and have them + // run before the exit. + gofmtMain() + os.Exit(exitCode) +} + +func gofmtMain() { + flag.Usage = usage + flag.Parse() + + if options.TabWidth < 0 { + fmt.Fprintf(os.Stderr, "negative tabwidth %d\n", options.TabWidth) + exitCode = 2 + return + } + + if flag.NArg() == 0 { + if err := processFile("", os.Stdin, os.Stdout, true); err != nil { + report(err) + } + return + } + + for i := 0; i < flag.NArg(); i++ { + path := flag.Arg(i) + switch dir, err := os.Stat(path); { + case err != nil: + report(err) + case dir.IsDir(): + walkDir(path) + default: + if err := processFile(path, nil, os.Stdout, false); err != nil { + report(err) + } + } + } +} + +func diff(b1, b2 []byte) (data []byte, err error) { + f1, err := ioutil.TempFile("", "gofmt") + if err != nil { + return + } + defer os.Remove(f1.Name()) + defer f1.Close() + + f2, err := ioutil.TempFile("", "gofmt") + if err != nil { + return + } + defer os.Remove(f2.Name()) + defer f2.Close() + + f1.Write(b1) + f2.Write(b2) + + data, err = exec.Command("diff", "-u", f1.Name(), f2.Name()).CombinedOutput() + if len(data) > 0 { + // diff exits with a non-zero status when the files don't match. + // Ignore that failure as long as we get output. + err = nil + } + return +} diff --git a/imports/fix.go b/imports/fix.go new file mode 100644 index 00000000..12d382a0 --- /dev/null +++ b/imports/fix.go @@ -0,0 +1,328 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package imports + +import ( + "fmt" + "go/ast" + "go/build" + "go/parser" + "go/token" + "os" + "path" + "path/filepath" + "strings" + "sync" + + "code.google.com/p/go.tools/astutil" +) + +// importToGroup is a list of functions which map from an import path to +// a group number. +var importToGroup = []func(importPath string) (num int, ok bool){ + func(importPath string) (num int, ok bool) { + if strings.HasPrefix(importPath, "appengine") { + return 2, true + } + return + }, + func(importPath string) (num int, ok bool) { + if strings.Contains(importPath, ".") { + return 1, true + } + return + }, +} + +func importGroup(importPath string) int { + for _, fn := range importToGroup { + if n, ok := fn(importPath); ok { + return n + } + } + return 0 +} + +func fixImports(f *ast.File) (added []string, err error) { + // refs are a set of possible package references currently unsatisified by imports. + // first key: either base package (e.g. "fmt") or renamed package + // second key: referenced package symbol (e.g. "Println") + refs := make(map[string]map[string]bool) + + // decls are the current package imports. key is base package or renamed package. + decls := make(map[string]*ast.ImportSpec) + + // collect potential uses of packages. + var visitor visitFn + visitor = visitFn(func(node ast.Node) ast.Visitor { + if node == nil { + return visitor + } + switch v := node.(type) { + case *ast.ImportSpec: + if v.Name != nil { + decls[v.Name.Name] = v + } else { + local := importPathToName(strings.Trim(v.Path.Value, `\"`)) + decls[local] = v + } + case *ast.SelectorExpr: + xident, ok := v.X.(*ast.Ident) + if !ok { + break + } + if xident.Obj != nil { + // if the parser can resolve it, it's not a package ref + break + } + pkgName := xident.Name + if refs[pkgName] == nil { + refs[pkgName] = make(map[string]bool) + } + if decls[pkgName] == nil { + refs[pkgName][v.Sel.Name] = true + } + } + return visitor + }) + ast.Walk(visitor, f) + + // Search for imports matching potential package references. + searches := 0 + type result struct { + ipath string + err error + } + results := make(chan result) + for pkgName, symbols := range refs { + if len(symbols) == 0 { + continue // skip over packages already imported + } + go func(pkgName string, symbols map[string]bool) { + ipath, err := findImport(pkgName, symbols) + results <- result{ipath, err} + }(pkgName, symbols) + searches++ + } + for i := 0; i < searches; i++ { + result := <-results + if result.err != nil { + return nil, result.err + } + if result.ipath != "" { + astutil.AddImport(fset, f, result.ipath) + added = append(added, result.ipath) + } + } + + // Nil out any unused ImportSpecs, to be removed in following passes + unusedImport := map[string]bool{} + for pkg, is := range decls { + if refs[pkg] == nil && pkg != "_" && pkg != "." { + unusedImport[strings.Trim(is.Path.Value, `"`)] = true + } + } + for ipath := range unusedImport { + if ipath == "C" { + // Don't remove cgo stuff. + continue + } + astutil.DeleteImport(fset, f, ipath) + } + + return added, nil +} + +// importPathToName returns the package name for the given import path. +var importPathToName = importPathToNameGoPath + +// importPathToNameBasic assumes the package name is the base of import path. +func importPathToNameBasic(importPath string) (packageName string) { + return path.Base(importPath) +} + +// importPathToNameGoPath finds out the actual package name, as declared in its .go files. +// If there's a problem, it falls back to using importPathToNameBasic. +func importPathToNameGoPath(importPath string) (packageName string) { + if buildPkg, err := build.Import(importPath, "", 0); err == nil { + return buildPkg.Name + } else { + return importPathToNameBasic(importPath) + } +} + +type pkg struct { + importpath string // full pkg import path, e.g. "net/http" + dir string // absolute file path to pkg directory e.g. "/usr/lib/go/src/fmt" +} + +var pkgIndexOnce sync.Once + +var pkgIndex struct { + sync.Mutex + m map[string][]pkg // shortname => []pkg, e.g "http" => "net/http" +} + +func loadPkgIndex() { + pkgIndex.Lock() + pkgIndex.m = make(map[string][]pkg) + pkgIndex.Unlock() + + var wg sync.WaitGroup + for _, path := range build.Default.SrcDirs() { + f, err := os.Open(path) + if err != nil { + fmt.Fprint(os.Stderr, err) + continue + } + children, err := f.Readdir(-1) + f.Close() + if err != nil { + fmt.Fprint(os.Stderr, err) + continue + } + for _, child := range children { + if child.IsDir() { + wg.Add(1) + go func(path, name string) { + defer wg.Done() + loadPkg(&wg, path, name) + }(path, child.Name()) + } + } + } + wg.Wait() +} + +var fset = token.NewFileSet() + +func loadPkg(wg *sync.WaitGroup, root, pkgrelpath string) { + importpath := filepath.ToSlash(pkgrelpath) + shortName := importPathToName(importpath) + + dir := filepath.Join(root, importpath) + pkgIndex.Lock() + pkgIndex.m[shortName] = append(pkgIndex.m[shortName], pkg{ + importpath: importpath, + dir: dir, + }) + pkgIndex.Unlock() + + pkgDir, err := os.Open(dir) + if err != nil { + return + } + children, err := pkgDir.Readdir(-1) + pkgDir.Close() + if err != nil { + return + } + for _, child := range children { + name := child.Name() + if name == "" { + continue + } + if c := name[0]; c == '.' || ('0' <= c && c <= '9') { + continue + } + if child.IsDir() { + wg.Add(1) + go func(root, name string) { + defer wg.Done() + loadPkg(wg, root, name) + }(root, filepath.Join(importpath, name)) + } + } +} + +// loadExports returns a list exports for a package. +var loadExports = loadExportsGoPath + +func loadExportsGoPath(dir string) map[string]bool { + exports := make(map[string]bool) + buildPkg, err := build.ImportDir(dir, 0) + if err != nil { + if strings.Contains(err.Error(), "no buildable Go source files in") { + return nil + } + fmt.Fprintf(os.Stderr, "could not import %q: %v", dir, err) + return nil + } + for _, file := range buildPkg.GoFiles { + f, err := parser.ParseFile(fset, filepath.Join(dir, file), nil, 0) + if err != nil { + fmt.Fprintf(os.Stderr, "could not parse %q: %v", file, err) + continue + } + for name := range f.Scope.Objects { + if ast.IsExported(name) { + exports[name] = true + } + } + } + return exports +} + +// findImport searches for a package with the given symbols. +// If no package is found, findImport returns "". +// Declared as a variable rather than a function so goimports can be easily +// extended by adding a file with an init function. +var findImport = findImportGoPath + +func findImportGoPath(pkgName string, symbols map[string]bool) (string, error) { + + pkgIndexOnce.Do(loadPkgIndex) + + // Collect exports for packages with matching names. + var wg sync.WaitGroup + var pkgsMu sync.Mutex // guards pkgs + // full importpath => exported symbol => True + // e.g. "net/http" => "Client" => True + pkgs := make(map[string]map[string]bool) + pkgIndex.Lock() + for _, pkg := range pkgIndex.m[pkgName] { + wg.Add(1) + go func(importpath, dir string) { + defer wg.Done() + exports := loadExports(dir) + if exports != nil { + pkgsMu.Lock() + pkgs[importpath] = exports + pkgsMu.Unlock() + } + }(pkg.importpath, pkg.dir) + } + pkgIndex.Unlock() + wg.Wait() + + // Filter out packages missing required exported symbols. + for symbol := range symbols { + for importpath, exports := range pkgs { + if !exports[symbol] { + delete(pkgs, importpath) + } + } + } + if len(pkgs) == 0 { + return "", nil + } + + // If there are multiple candidate packages, the shortest one wins. + // This is a heuristic to prefer the standard library (e.g. "bytes") + // over e.g. "github.com/foo/bar/bytes". + shortest := "" + for importPath := range pkgs { + if shortest == "" || len(importPath) < len(shortest) { + shortest = importPath + } + } + return shortest, nil +} + +type visitFn func(node ast.Node) ast.Visitor + +func (fn visitFn) Visit(node ast.Node) ast.Visitor { + return fn(node) +} diff --git a/imports/fix_test.go b/imports/fix_test.go new file mode 100644 index 00000000..e18fd989 --- /dev/null +++ b/imports/fix_test.go @@ -0,0 +1,531 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package imports + +import ( + "bytes" + "flag" + "go/build" + "io/ioutil" + "os" + "path/filepath" + "strings" + "testing" +) + +var only = flag.String("only", "", "If non-empty, the fix test to run") + +var tests = []struct { + name string + in, out string +}{ + // Adding an import to an existing parenthesized import + { + name: "factored_imports_add", + in: `package foo +import ( + "fmt" +) +func bar() { +var b bytes.Buffer +fmt.Println(b.String()) +} +`, + out: `package foo + +import ( + "bytes" + "fmt" +) + +func bar() { + var b bytes.Buffer + fmt.Println(b.String()) +} +`, + }, + + // Adding an import to an existing parenthesized import, + // verifying it goes into the first section. + { + name: "factored_imports_add_first_sec", + in: `package foo +import ( + "fmt" + + "appengine" +) +func bar() { +var b bytes.Buffer +_ = appengine.IsDevServer +fmt.Println(b.String()) +} +`, + out: `package foo + +import ( + "bytes" + "fmt" + + "appengine" +) + +func bar() { + var b bytes.Buffer + _ = appengine.IsDevServer + fmt.Println(b.String()) +} +`, + }, + + // Adding an import to an existing parenthesized import, + // verifying it goes into the first section. (test 2) + { + name: "factored_imports_add_first_sec_2", + in: `package foo +import ( + "fmt" + + "appengine" +) +func bar() { +_ = math.NaN +_ = fmt.Sprintf +_ = appengine.IsDevServer +} +`, + out: `package foo + +import ( + "fmt" + "math" + + "appengine" +) + +func bar() { + _ = math.NaN + _ = fmt.Sprintf + _ = appengine.IsDevServer +} +`, + }, + + // Adding a new import line, without parens + { + name: "add_import_section", + in: `package foo +func bar() { +var b bytes.Buffer +} +`, + out: `package foo + +import "bytes" + +func bar() { + var b bytes.Buffer +} +`, + }, + + // Adding two new imports, which should make a parenthesized import decl. + { + name: "add_import_paren_section", + in: `package foo +func bar() { +_, _ := bytes.Buffer, zip.NewReader +} +`, + out: `package foo + +import ( + "archive/zip" + "bytes" +) + +func bar() { + _, _ := bytes.Buffer, zip.NewReader +} +`, + }, + + // Make sure we don't add things twice + { + name: "no_double_add", + in: `package foo +func bar() { +_, _ := bytes.Buffer, bytes.NewReader +} +`, + out: `package foo + +import "bytes" + +func bar() { + _, _ := bytes.Buffer, bytes.NewReader +} +`, + }, + + // Remove unused imports, 1 of a factored block + { + name: "remove_unused_1_of_2", + in: `package foo +import ( +"bytes" +"fmt" +) + +func bar() { +_, _ := bytes.Buffer, bytes.NewReader +} +`, + out: `package foo + +import "bytes" + +func bar() { + _, _ := bytes.Buffer, bytes.NewReader +} +`, + }, + + // Remove unused imports, 2 of 2 + { + name: "remove_unused_2_of_2", + in: `package foo +import ( +"bytes" +"fmt" +) + +func bar() { +} +`, + out: `package foo + +func bar() { +} +`, + }, + + // Remove unused imports, 1 of 1 + { + name: "remove_unused_1_of_1", + in: `package foo + +import "fmt" + +func bar() { +} +`, + out: `package foo + +func bar() { +} +`, + }, + + // Don't remove empty imports. + { + name: "dont_remove_empty_imports", + in: `package foo +import ( +_ "image/png" +_ "image/jpeg" +) +`, + out: `package foo + +import ( + _ "image/jpeg" + _ "image/png" +) +`, + }, + + // Don't remove dot imports. + { + name: "dont_remove_dot_imports", + in: `package foo +import ( +. "foo" +. "bar" +) +`, + out: `package foo + +import ( + . "bar" + . "foo" +) +`, + }, + + // Skip refs the parser can resolve. + { + name: "skip_resolved_refs", + in: `package foo + +func f() { + type t struct{ Println func(string) } + fmt := t{Println: func(string) {}} + fmt.Println("foo") +} +`, + out: `package foo + +func f() { + type t struct{ Println func(string) } + fmt := t{Println: func(string) {}} + fmt.Println("foo") +} +`, + }, + + // Do not add a package we already have a resolution for. + { + name: "skip_template", + in: `package foo + +import "html/template" + +func f() { t = template.New("sometemplate") } +`, + out: `package foo + +import "html/template" + +func f() { t = template.New("sometemplate") } +`, + }, + + // Don't touch cgo + { + name: "cgo", + in: `package foo + +/* +#include +*/ +import "C" +`, + out: `package foo + +/* +#include +*/ +import "C" +`, + }, + + // Put some things in their own section + { + name: "make_sections", + in: `package foo + +import ( +"os" +) + +func foo () { +_, _ = os.Args, fmt.Println +_, _ = appengine.FooSomething, user.Current +} +`, + out: `package foo + +import ( + "fmt" + "os" + + "appengine" + "appengine/user" +) + +func foo() { + _, _ = os.Args, fmt.Println + _, _ = appengine.FooSomething, user.Current +} +`, + }, + + // Delete existing empty import block + { + name: "delete_empty_import_block", + in: `package foo + +import () +`, + out: `package foo +`, + }, + + // Use existing empty import block + { + name: "use_empty_import_block", + in: `package foo + +import () + +func f() { + _ = fmt.Println +} +`, + out: `package foo + +import "fmt" + +func f() { + _ = fmt.Println +} +`, + }, + + // Blank line before adding new section. + { + name: "blank_line_before_new_group", + in: `package foo + +import ( + "fmt" + "net" +) + +func f() { + _ = net.Dial + _ = fmt.Printf + _ = snappy.Foo +} +`, + out: `package foo + +import ( + "fmt" + "net" + + "code.google.com/p/snappy-go/snappy" +) + +func f() { + _ = net.Dial + _ = fmt.Printf + _ = snappy.Foo +} +`, + }, + + // Blank line between standard library and third-party stuff. + { + name: "blank_line_separating_std_and_third_party", + in: `package foo + +import ( + "code.google.com/p/snappy-go/snappy" + "fmt" + "net" +) + +func f() { + _ = net.Dial + _ = fmt.Printf + _ = snappy.Foo +} +`, + out: `package foo + +import ( + "fmt" + "net" + + "code.google.com/p/snappy-go/snappy" +) + +func f() { + _ = net.Dial + _ = fmt.Printf + _ = snappy.Foo +} +`, + }, +} + +func TestFixImports(t *testing.T) { + simplePkgs := map[string]string{ + "fmt": "fmt", + "os": "os", + "math": "math", + "appengine": "appengine", + "user": "appengine/user", + "zip": "archive/zip", + "bytes": "bytes", + "snappy": "code.google.com/p/snappy-go/snappy", + } + findImport = func(pkgName string, symbols map[string]bool) (string, error) { + return simplePkgs[pkgName], nil + } + + for _, tt := range tests { + if *only != "" && tt.name != *only { + continue + } + var buf bytes.Buffer + err := processFile("foo.go", strings.NewReader(tt.in), &buf, false) + if err != nil { + t.Errorf("error on %q: %v", tt.name, err) + continue + } + if got := buf.String(); got != tt.out { + t.Errorf("results diff on %q\nGOT:\n%s\nWANT:\n%s\n", tt.name, got, tt.out) + } + } +} + +func TestFindImportGoPath(t *testing.T) { + goroot, err := ioutil.TempDir("", "goimports-") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(goroot) + // Test against imaginary bits/bytes package in std lib + bytesDir := filepath.Join(goroot, "src", "pkg", "bits", "bytes") + if err := os.MkdirAll(bytesDir, 0755); err != nil { + t.Fatal(err) + } + bytesSrcPath := filepath.Join(bytesDir, "bytes.go") + bytesPkgPath := "bits/bytes" + bytesSrc := []byte(`package bytes + +type Buffer2 struct {} +`) + if err := ioutil.WriteFile(bytesSrcPath, bytesSrc, 0775); err != nil { + t.Fatal(err) + } + oldGOROOT := build.Default.GOROOT + oldGOPATH := build.Default.GOPATH + build.Default.GOROOT = goroot + build.Default.GOPATH = "" + defer func() { + build.Default.GOROOT = oldGOROOT + build.Default.GOPATH = oldGOPATH + }() + + got, err := findImportGoPath("bytes", map[string]bool{"Buffer2": true}) + if err != nil { + t.Fatal(err) + } + if got != bytesPkgPath { + t.Errorf(`findImportGoPath("bytes", Buffer2 ...)=%q, want "%s"`, got, bytesPkgPath) + } + + got, err = findImportGoPath("bytes", map[string]bool{"Missing": true}) + if err != nil { + t.Fatal(err) + } + if got != "" { + t.Errorf(`findImportGoPath("bytes", Missing ...)=%q, want ""`, got) + } +} diff --git a/imports/imports.go b/imports/imports.go new file mode 100644 index 00000000..6da3a512 --- /dev/null +++ b/imports/imports.go @@ -0,0 +1,240 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package imports implements a Go pretty-printer (like package "go/format") +// that also adds or removes import statements as necessary. +package imports + +import ( + "bufio" + "bytes" + "fmt" + "go/ast" + "go/parser" + "go/printer" + "go/token" + "io" + "regexp" + "strconv" + "strings" + + "code.google.com/p/go.tools/astutil" +) + +// Options specifies options for processing files. +type Options struct { + Fragment bool // Accept fragement of a source file (no package statement) + AllErrors bool // Report all errors (not just the first 10 on different lines) + + Comments bool // Print comments (true if nil *Options provided) + TabIndent bool // Use tabs for indent (true if nil *Options provided) + TabWidth int // Tab width (8 if nil *Options provided) +} + +// Process formats and adjusts imports for the provided file. +// If opt is nil the defaults are used. +func Process(filename string, src []byte, opt *Options) ([]byte, error) { + if opt == nil { + opt = &Options{Comments: true, TabIndent: true, TabWidth: 8} + } + + fileSet := token.NewFileSet() + file, adjust, err := parse(fileSet, filename, src, opt) + if err != nil { + return nil, err + } + + _, err = fixImports(file) + if err != nil { + return nil, err + } + + sortImports(fileSet, file) + imps := astutil.Imports(fileSet, file) + + var spacesBefore []string // import paths we need spaces before + if len(imps) == 1 { + // We have just one block of imports. See if any are in different groups numbers. + lastGroup := -1 + for _, importSpec := range imps[0] { + importPath, _ := strconv.Unquote(importSpec.Path.Value) + groupNum := importGroup(importPath) + if groupNum != lastGroup && lastGroup != -1 { + spacesBefore = append(spacesBefore, importPath) + } + lastGroup = groupNum + } + + } + + printerMode := printer.UseSpaces + if opt.TabIndent { + printerMode |= printer.TabIndent + } + printConfig := &printer.Config{Mode: printerMode, Tabwidth: opt.TabWidth} + + var buf bytes.Buffer + err = printConfig.Fprint(&buf, fileSet, file) + if err != nil { + return nil, err + } + out := buf.Bytes() + if adjust != nil { + out = adjust(src, out) + } + if len(spacesBefore) > 0 { + out = addImportSpaces(bytes.NewReader(out), spacesBefore) + } + return out, nil +} + +// parse parses src, which was read from filename, +// as a Go source file or statement list. +func parse(fset *token.FileSet, filename string, src []byte, opt *Options) (*ast.File, func(orig, src []byte) []byte, error) { + parserMode := parser.Mode(0) + if opt.Comments { + parserMode |= parser.ParseComments + } + if opt.AllErrors { + parserMode |= parser.AllErrors + } + + // Try as whole source file. + file, err := parser.ParseFile(fset, filename, src, parserMode) + if err == nil { + return file, nil, nil + } + // If the error is that the source file didn't begin with a + // package line and we accept fragmented input, fall through to + // try as a source fragment. Stop and return on any other error. + if !opt.Fragment || !strings.Contains(err.Error(), "expected 'package'") { + return nil, nil, err + } + + // If this is a declaration list, make it a source file + // by inserting a package clause. + // Insert using a ;, not a newline, so that the line numbers + // in psrc match the ones in src. + psrc := append([]byte("package p;"), src...) + file, err = parser.ParseFile(fset, filename, psrc, parserMode) + if err == nil { + adjust := func(orig, src []byte) []byte { + // Remove the package clause. + // Gofmt has turned the ; into a \n. + src = src[len("package p\n"):] + return matchSpace(orig, src) + } + return file, adjust, nil + } + // If the error is that the source file didn't begin with a + // declaration, fall through to try as a statement list. + // Stop and return on any other error. + if !strings.Contains(err.Error(), "expected declaration") { + return nil, nil, err + } + + // If this is a statement list, make it a source file + // by inserting a package clause and turning the list + // into a function body. This handles expressions too. + // Insert using a ;, not a newline, so that the line numbers + // in fsrc match the ones in src. + fsrc := append(append([]byte("package p; func _() {"), src...), '}') + file, err = parser.ParseFile(fset, filename, fsrc, parserMode) + if err == nil { + adjust := func(orig, src []byte) []byte { + // Remove the wrapping. + // Gofmt has turned the ; into a \n\n. + src = src[len("package p\n\nfunc _() {"):] + src = src[:len(src)-len("}\n")] + // Gofmt has also indented the function body one level. + // Remove that indent. + src = bytes.Replace(src, []byte("\n\t"), []byte("\n"), -1) + return matchSpace(orig, src) + } + return file, adjust, nil + } + + // Failed, and out of options. + return nil, nil, err +} + +func cutSpace(b []byte) (before, middle, after []byte) { + i := 0 + for i < len(b) && (b[i] == ' ' || b[i] == '\t' || b[i] == '\n') { + i++ + } + j := len(b) + for j > 0 && (b[j-1] == ' ' || b[j-1] == '\t' || b[j-1] == '\n') { + j-- + } + if i <= j { + return b[:i], b[i:j], b[j:] + } + return nil, nil, b[j:] +} + +// matchSpace reformats src to use the same space context as orig. +// 1) If orig begins with blank lines, matchSpace inserts them at the beginning of src. +// 2) matchSpace copies the indentation of the first non-blank line in orig +// to every non-blank line in src. +// 3) matchSpace copies the trailing space from orig and uses it in place +// of src's trailing space. +func matchSpace(orig []byte, src []byte) []byte { + before, _, after := cutSpace(orig) + i := bytes.LastIndex(before, []byte{'\n'}) + before, indent := before[:i+1], before[i+1:] + + _, src, _ = cutSpace(src) + + var b bytes.Buffer + b.Write(before) + for len(src) > 0 { + line := src + if i := bytes.IndexByte(line, '\n'); i >= 0 { + line, src = line[:i+1], line[i+1:] + } else { + src = nil + } + if len(line) > 0 && line[0] != '\n' { // not blank + b.Write(indent) + } + b.Write(line) + } + b.Write(after) + return b.Bytes() +} + +var impLine = regexp.MustCompile(`^\s+(?:\w+\s+)?"(.+)"`) + +func addImportSpaces(r io.Reader, breaks []string) []byte { + var out bytes.Buffer + sc := bufio.NewScanner(r) + inImports := false + done := false + for sc.Scan() { + s := sc.Text() + + if !inImports && !done && strings.HasPrefix(s, "import") { + inImports = true + } + if inImports && (strings.HasPrefix(s, "var") || + strings.HasPrefix(s, "func") || + strings.HasPrefix(s, "const") || + strings.HasPrefix(s, "type")) { + done = true + inImports = false + } + if inImports && len(breaks) > 0 { + if m := impLine.FindStringSubmatch(s); m != nil { + if m[1] == string(breaks[0]) { + out.WriteByte('\n') + breaks = breaks[1:] + } + } + } + + fmt.Fprintln(&out, s) + } + return out.Bytes() +} diff --git a/imports/mkindex.go b/imports/mkindex.go new file mode 100644 index 00000000..755e2394 --- /dev/null +++ b/imports/mkindex.go @@ -0,0 +1,173 @@ +// +build ignore + +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Command mkindex creates the file "pkgindex.go" containing an index of the Go +// standard library. The file is intended to be built as part of the imports +// package, so that the package may be used in environments where a GOROOT is +// not available (such as App Engine). +package main + +import ( + "bytes" + "fmt" + "go/ast" + "go/build" + "go/format" + "go/parser" + "go/token" + "io/ioutil" + "log" + "os" + "path" + "path/filepath" + "strings" +) + +var ( + pkgIndex = make(map[string][]pkg) + exports = make(map[string]map[string]bool) +) + +func main() { + // Don't use GOPATH. + ctx := build.Default + ctx.GOPATH = "" + + // Populate pkgIndex global from GOROOT. + for _, path := range ctx.SrcDirs() { + f, err := os.Open(path) + if err != nil { + log.Print(err) + continue + } + children, err := f.Readdir(-1) + f.Close() + if err != nil { + log.Print(err) + continue + } + for _, child := range children { + if child.IsDir() { + loadPkg(path, child.Name()) + } + } + } + // Populate exports global. + for _, ps := range pkgIndex { + for _, p := range ps { + e := loadExports(p.dir) + if e != nil { + exports[p.dir] = e + } + } + } + + // Construct source file. + var buf bytes.Buffer + fmt.Fprint(&buf, pkgIndexHead) + fmt.Fprintf(&buf, "var pkgIndexMaster = %#v\n", pkgIndex) + fmt.Fprintf(&buf, "var exportsMaster = %#v\n", exports) + src := buf.Bytes() + + // Replace main.pkg type name with pkg. + src = bytes.Replace(src, []byte("main.pkg"), []byte("pkg"), -1) + // Replace actual GOROOT with "/go". + src = bytes.Replace(src, []byte(ctx.GOROOT), []byte("/go"), -1) + // Add some line wrapping. + src = bytes.Replace(src, []byte("}, "), []byte("},\n"), -1) + src = bytes.Replace(src, []byte("true, "), []byte("true,\n"), -1) + + var err error + src, err = format.Source(src) + if err != nil { + log.Fatal(err) + } + + // Write out source file. + err = ioutil.WriteFile("pkgindex.go", src, 0644) + if err != nil { + log.Fatal(err) + } +} + +const pkgIndexHead = `package imports + +func init() { + pkgIndexOnce.Do(func() { + pkgIndex.m = pkgIndexMaster + }) + loadExports = func(dir string) map[string]bool { + return exportsMaster[dir] + } +} +` + +type pkg struct { + importpath string // full pkg import path, e.g. "net/http" + dir string // absolute file path to pkg directory e.g. "/usr/lib/go/src/fmt" +} + +var fset = token.NewFileSet() + +func loadPkg(root, importpath string) { + shortName := path.Base(importpath) + if shortName == "testdata" { + return + } + + dir := filepath.Join(root, importpath) + pkgIndex[shortName] = append(pkgIndex[shortName], pkg{ + importpath: importpath, + dir: dir, + }) + + pkgDir, err := os.Open(dir) + if err != nil { + return + } + children, err := pkgDir.Readdir(-1) + pkgDir.Close() + if err != nil { + return + } + for _, child := range children { + name := child.Name() + if name == "" { + continue + } + if c := name[0]; c == '.' || ('0' <= c && c <= '9') { + continue + } + if child.IsDir() { + loadPkg(root, filepath.Join(importpath, name)) + } + } +} + +func loadExports(dir string) map[string]bool { + exports := make(map[string]bool) + buildPkg, err := build.ImportDir(dir, 0) + if err != nil { + if strings.Contains(err.Error(), "no buildable Go source files in") { + return nil + } + log.Printf("could not import %q: %v", dir, err) + return nil + } + for _, file := range buildPkg.GoFiles { + f, err := parser.ParseFile(fset, filepath.Join(dir, file), nil, 0) + if err != nil { + log.Printf("could not parse %q: %v", file, err) + continue + } + for name := range f.Scope.Objects { + if ast.IsExported(name) { + exports[name] = true + } + } + } + return exports +} diff --git a/imports/sortimports.go b/imports/sortimports.go new file mode 100644 index 00000000..68b3dc4e --- /dev/null +++ b/imports/sortimports.go @@ -0,0 +1,214 @@ +// +build go1.2 + +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Hacked up copy of go/ast/import.go + +package imports + +import ( + "go/ast" + "go/token" + "sort" + "strconv" +) + +// sortImports sorts runs of consecutive import lines in import blocks in f. +// It also removes duplicate imports when it is possible to do so without data loss. +func sortImports(fset *token.FileSet, f *ast.File) { + for i, d := range f.Decls { + d, ok := d.(*ast.GenDecl) + if !ok || d.Tok != token.IMPORT { + // Not an import declaration, so we're done. + // Imports are always first. + break + } + + if len(d.Specs) == 0 { + // Empty import block, remove it. + f.Decls = append(f.Decls[:i], f.Decls[i+1:]...) + } + + if !d.Lparen.IsValid() { + // Not a block: sorted by default. + continue + } + + // Identify and sort runs of specs on successive lines. + i := 0 + specs := d.Specs[:0] + for j, s := range d.Specs { + if j > i && fset.Position(s.Pos()).Line > 1+fset.Position(d.Specs[j-1].End()).Line { + // j begins a new run. End this one. + specs = append(specs, sortSpecs(fset, f, d.Specs[i:j])...) + i = j + } + } + specs = append(specs, sortSpecs(fset, f, d.Specs[i:])...) + d.Specs = specs + + // Deduping can leave a blank line before the rparen; clean that up. + if len(d.Specs) > 0 { + lastSpec := d.Specs[len(d.Specs)-1] + lastLine := fset.Position(lastSpec.Pos()).Line + if rParenLine := fset.Position(d.Rparen).Line; rParenLine > lastLine+1 { + fset.File(d.Rparen).MergeLine(rParenLine - 1) + } + } + } +} + +func importPath(s ast.Spec) string { + t, err := strconv.Unquote(s.(*ast.ImportSpec).Path.Value) + if err == nil { + return t + } + return "" +} + +func importName(s ast.Spec) string { + n := s.(*ast.ImportSpec).Name + if n == nil { + return "" + } + return n.Name +} + +func importComment(s ast.Spec) string { + c := s.(*ast.ImportSpec).Comment + if c == nil { + return "" + } + return c.Text() +} + +// collapse indicates whether prev may be removed, leaving only next. +func collapse(prev, next ast.Spec) bool { + if importPath(next) != importPath(prev) || importName(next) != importName(prev) { + return false + } + return prev.(*ast.ImportSpec).Comment == nil +} + +type posSpan struct { + Start token.Pos + End token.Pos +} + +func sortSpecs(fset *token.FileSet, f *ast.File, specs []ast.Spec) []ast.Spec { + // Can't short-circuit here even if specs are already sorted, + // since they might yet need deduplication. + // A lone import, however, may be safely ignored. + if len(specs) <= 1 { + return specs + } + + // Record positions for specs. + pos := make([]posSpan, len(specs)) + for i, s := range specs { + pos[i] = posSpan{s.Pos(), s.End()} + } + + // Identify comments in this range. + // Any comment from pos[0].Start to the final line counts. + lastLine := fset.Position(pos[len(pos)-1].End).Line + cstart := len(f.Comments) + cend := len(f.Comments) + for i, g := range f.Comments { + if g.Pos() < pos[0].Start { + continue + } + if i < cstart { + cstart = i + } + if fset.Position(g.End()).Line > lastLine { + cend = i + break + } + } + comments := f.Comments[cstart:cend] + + // Assign each comment to the import spec preceding it. + importComment := map[*ast.ImportSpec][]*ast.CommentGroup{} + specIndex := 0 + for _, g := range comments { + for specIndex+1 < len(specs) && pos[specIndex+1].Start <= g.Pos() { + specIndex++ + } + s := specs[specIndex].(*ast.ImportSpec) + importComment[s] = append(importComment[s], g) + } + + // Sort the import specs by import path. + // Remove duplicates, when possible without data loss. + // Reassign the import paths to have the same position sequence. + // Reassign each comment to abut the end of its spec. + // Sort the comments by new position. + sort.Sort(byImportSpec(specs)) + + // Dedup. Thanks to our sorting, we can just consider + // adjacent pairs of imports. + deduped := specs[:0] + for i, s := range specs { + if i == len(specs)-1 || !collapse(s, specs[i+1]) { + deduped = append(deduped, s) + } else { + p := s.Pos() + fset.File(p).MergeLine(fset.Position(p).Line) + } + } + specs = deduped + + // Fix up comment positions + for i, s := range specs { + s := s.(*ast.ImportSpec) + if s.Name != nil { + s.Name.NamePos = pos[i].Start + } + s.Path.ValuePos = pos[i].Start + s.EndPos = pos[i].End + for _, g := range importComment[s] { + for _, c := range g.List { + c.Slash = pos[i].End + } + } + } + + sort.Sort(byCommentPos(comments)) + + return specs +} + +type byImportSpec []ast.Spec // slice of *ast.ImportSpec + +func (x byImportSpec) Len() int { return len(x) } +func (x byImportSpec) Swap(i, j int) { x[i], x[j] = x[j], x[i] } +func (x byImportSpec) Less(i, j int) bool { + ipath := importPath(x[i]) + jpath := importPath(x[j]) + + igroup := importGroup(ipath) + jgroup := importGroup(jpath) + if igroup != jgroup { + return igroup < jgroup + } + + if ipath != jpath { + return ipath < jpath + } + iname := importName(x[i]) + jname := importName(x[j]) + + if iname != jname { + return iname < jname + } + return importComment(x[i]) < importComment(x[j]) +} + +type byCommentPos []*ast.CommentGroup + +func (x byCommentPos) Len() int { return len(x) } +func (x byCommentPos) Swap(i, j int) { x[i], x[j] = x[j], x[i] } +func (x byCommentPos) Less(i, j int) bool { return x[i].Pos() < x[j].Pos() } diff --git a/imports/sortimports_compat.go b/imports/sortimports_compat.go new file mode 100644 index 00000000..295f237a --- /dev/null +++ b/imports/sortimports_compat.go @@ -0,0 +1,14 @@ +// +build !go1.2 + +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package imports + +import "go/ast" + +// Go 1.1 users don't get fancy package grouping. +// But this is still gofmt-compliant: + +var sortImports = ast.SortImports