imports: extend findImports to return a boolean, rename, that tells

goimports to use the package name as a local qualifier in an import.
For example, if findImports("pkg", "X") returns ("foo/bar",
rename=true), then goimports adds the import line:
  import pkg "foo/bar"
to satisfy uses of pkg.X in the file.

This change doesn't add any implementations of rename=true, though one
is sketched in a TODO.

LGTM=crawshaw
R=crawshaw, rsc
CC=bradfitz, golang-codereviews
https://golang.org/cl/76400050
This commit is contained in:
Sameer Ajmani 2014-03-25 09:37:10 -04:00
parent a781b00b0d
commit a1c1cf19ba
2 changed files with 52 additions and 23 deletions

View File

@ -93,6 +93,7 @@ func fixImports(fset *token.FileSet, f *ast.File) (added []string, err error) {
searches := 0 searches := 0
type result struct { type result struct {
ipath string ipath string
name string
err error err error
} }
results := make(chan result) results := make(chan result)
@ -101,8 +102,12 @@ func fixImports(fset *token.FileSet, f *ast.File) (added []string, err error) {
continue // skip over packages already imported continue // skip over packages already imported
} }
go func(pkgName string, symbols map[string]bool) { go func(pkgName string, symbols map[string]bool) {
ipath, err := findImport(pkgName, symbols) ipath, rename, err := findImport(pkgName, symbols)
results <- result{ipath, err} r := result{ipath: ipath, err: err}
if rename {
r.name = pkgName
}
results <- r
}(pkgName, symbols) }(pkgName, symbols)
searches++ searches++
} }
@ -112,7 +117,11 @@ func fixImports(fset *token.FileSet, f *ast.File) (added []string, err error) {
return nil, result.err return nil, result.err
} }
if result.ipath != "" { if result.ipath != "" {
if result.name != "" {
astutil.AddNamedImport(fset, f, result.name, result.ipath)
} else {
astutil.AddImport(fset, f, result.ipath) astutil.AddImport(fset, f, result.ipath)
}
added = append(added, result.ipath) added = append(added, result.ipath)
} }
} }
@ -270,14 +279,19 @@ func loadExportsGoPath(dir string) map[string]bool {
// extended by adding a file with an init function. // extended by adding a file with an init function.
var findImport = findImportGoPath var findImport = findImportGoPath
func findImportGoPath(pkgName string, symbols map[string]bool) (string, error) { func findImportGoPath(pkgName string, symbols map[string]bool) (string, bool, error) {
// Fast path for the standard library. // Fast path for the standard library.
// In the common case we hopefully never have to scan the GOPATH, which can // In the common case we hopefully never have to scan the GOPATH, which can
// be slow with moving disks. // be slow with moving disks.
if pkg, ok := findImportStdlib(pkgName, symbols); ok { if pkg, rename, ok := findImportStdlib(pkgName, symbols); ok {
return pkg, nil return pkg, rename, nil
} }
// TODO(sameer): look at the import lines for other Go files in the
// local directory, since the user is likely to import the same packages
// in the current Go file. Return rename=true when the other Go files
// use a renamed package that's also used in the current file.
pkgIndexOnce.Do(loadPkgIndex) pkgIndexOnce.Do(loadPkgIndex)
// Collect exports for packages with matching names. // Collect exports for packages with matching names.
@ -311,7 +325,7 @@ func findImportGoPath(pkgName string, symbols map[string]bool) (string, error) {
} }
} }
if len(pkgs) == 0 { if len(pkgs) == 0 {
return "", nil return "", false, nil
} }
// If there are multiple candidate packages, the shortest one wins. // If there are multiple candidate packages, the shortest one wins.
@ -323,7 +337,7 @@ func findImportGoPath(pkgName string, symbols map[string]bool) (string, error) {
shortest = importPath shortest = importPath
} }
} }
return shortest, nil return shortest, false, nil
} }
type visitFn func(node ast.Node) ast.Visitor type visitFn func(node ast.Node) ast.Visitor
@ -332,17 +346,17 @@ func (fn visitFn) Visit(node ast.Node) ast.Visitor {
return fn(node) return fn(node)
} }
func findImportStdlib(shortPkg string, symbols map[string]bool) (importPath string, ok bool) { func findImportStdlib(shortPkg string, symbols map[string]bool) (importPath string, rename, ok bool) {
for symbol := range symbols { for symbol := range symbols {
path := stdlib[shortPkg+"."+symbol] path := stdlib[shortPkg+"."+symbol]
if path == "" { if path == "" {
return "", false return "", false, false
} }
if importPath != "" && importPath != path { if importPath != "" && importPath != path {
// Ambiguous. Symbols pointed to different things. // Ambiguous. Symbols pointed to different things.
return "", false return "", false, false
} }
importPath = path importPath = path
} }
return importPath, importPath != "" return importPath, false, importPath != ""
} }

View File

@ -505,6 +505,20 @@ var (
b = gu.a b = gu.a
c = fmt.Printf c = fmt.Printf
) )
`,
},
{
name: "renamed package",
in: `package main
var _ = str.HasPrefix
`,
out: `package main
import str "strings"
var _ = str.HasPrefix
`, `,
}, },
} }
@ -519,9 +533,10 @@ func TestFixImports(t *testing.T) {
"zip": "archive/zip", "zip": "archive/zip",
"bytes": "bytes", "bytes": "bytes",
"snappy": "code.google.com/p/snappy-go/snappy", "snappy": "code.google.com/p/snappy-go/snappy",
"str": "strings",
} }
findImport = func(pkgName string, symbols map[string]bool) (string, error) { findImport = func(pkgName string, symbols map[string]bool) (string, bool, error) {
return simplePkgs[pkgName], nil return simplePkgs[pkgName], pkgName == "str", nil
} }
for _, tt := range tests { for _, tt := range tests {
@ -577,20 +592,20 @@ type Buffer2 struct {}
build.Default.GOPATH = oldGOPATH build.Default.GOPATH = oldGOPATH
}() }()
got, err := findImportGoPath("bytes", map[string]bool{"Buffer2": true}) got, rename, err := findImportGoPath("bytes", map[string]bool{"Buffer2": true})
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if got != bytesPkgPath { if got != bytesPkgPath || rename {
t.Errorf(`findImportGoPath("bytes", Buffer2 ...)=%q, want "%s"`, got, bytesPkgPath) t.Errorf(`findImportGoPath("bytes", Buffer2 ...)=%q, %t, want "%s", false`, got, rename, bytesPkgPath)
} }
got, err = findImportGoPath("bytes", map[string]bool{"Missing": true}) got, rename, err = findImportGoPath("bytes", map[string]bool{"Missing": true})
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if got != "" { if got != "" || rename {
t.Errorf(`findImportGoPath("bytes", Missing ...)=%q, want ""`, got) t.Errorf(`findImportGoPath("bytes", Missing ...)=%q, %t, want "", false`, got, rename)
} }
} }
@ -607,12 +622,12 @@ func TestFindImportStdlib(t *testing.T) {
{"ioutil", []string{"Discard"}, "io/ioutil"}, {"ioutil", []string{"Discard"}, "io/ioutil"},
} }
for _, tt := range tests { for _, tt := range tests {
got, ok := findImportStdlib(tt.pkg, strSet(tt.symbols)) got, rename, ok := findImportStdlib(tt.pkg, strSet(tt.symbols))
if (got != "") != ok { if (got != "") != ok {
t.Error("findImportStdlib return value inconsistent") t.Error("findImportStdlib return value inconsistent")
} }
if got != tt.want { if got != tt.want || rename {
t.Errorf("findImportStdlib(%q, %q) = %q; want %q", tt.pkg, tt.symbols, got, tt.want) t.Errorf("findImportStdlib(%q, %q) = %q, %t; want %q, false", tt.pkg, tt.symbols, got, rename, tt.want)
} }
} }
} }