imports: extend findImports to return a boolean, rename, that tells
goimports to use the package name as a local qualifier in an import. For example, if findImports("pkg", "X") returns ("foo/bar", rename=true), then goimports adds the import line: import pkg "foo/bar" to satisfy uses of pkg.X in the file. This change doesn't add any implementations of rename=true, though one is sketched in a TODO. LGTM=crawshaw R=crawshaw, rsc CC=bradfitz, golang-codereviews https://golang.org/cl/76400050
This commit is contained in:
parent
a781b00b0d
commit
a1c1cf19ba
|
@ -93,6 +93,7 @@ func fixImports(fset *token.FileSet, f *ast.File) (added []string, err error) {
|
||||||
searches := 0
|
searches := 0
|
||||||
type result struct {
|
type result struct {
|
||||||
ipath string
|
ipath string
|
||||||
|
name string
|
||||||
err error
|
err error
|
||||||
}
|
}
|
||||||
results := make(chan result)
|
results := make(chan result)
|
||||||
|
@ -101,8 +102,12 @@ func fixImports(fset *token.FileSet, f *ast.File) (added []string, err error) {
|
||||||
continue // skip over packages already imported
|
continue // skip over packages already imported
|
||||||
}
|
}
|
||||||
go func(pkgName string, symbols map[string]bool) {
|
go func(pkgName string, symbols map[string]bool) {
|
||||||
ipath, err := findImport(pkgName, symbols)
|
ipath, rename, err := findImport(pkgName, symbols)
|
||||||
results <- result{ipath, err}
|
r := result{ipath: ipath, err: err}
|
||||||
|
if rename {
|
||||||
|
r.name = pkgName
|
||||||
|
}
|
||||||
|
results <- r
|
||||||
}(pkgName, symbols)
|
}(pkgName, symbols)
|
||||||
searches++
|
searches++
|
||||||
}
|
}
|
||||||
|
@ -112,7 +117,11 @@ func fixImports(fset *token.FileSet, f *ast.File) (added []string, err error) {
|
||||||
return nil, result.err
|
return nil, result.err
|
||||||
}
|
}
|
||||||
if result.ipath != "" {
|
if result.ipath != "" {
|
||||||
astutil.AddImport(fset, f, result.ipath)
|
if result.name != "" {
|
||||||
|
astutil.AddNamedImport(fset, f, result.name, result.ipath)
|
||||||
|
} else {
|
||||||
|
astutil.AddImport(fset, f, result.ipath)
|
||||||
|
}
|
||||||
added = append(added, result.ipath)
|
added = append(added, result.ipath)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -270,14 +279,19 @@ func loadExportsGoPath(dir string) map[string]bool {
|
||||||
// extended by adding a file with an init function.
|
// extended by adding a file with an init function.
|
||||||
var findImport = findImportGoPath
|
var findImport = findImportGoPath
|
||||||
|
|
||||||
func findImportGoPath(pkgName string, symbols map[string]bool) (string, error) {
|
func findImportGoPath(pkgName string, symbols map[string]bool) (string, bool, error) {
|
||||||
// Fast path for the standard library.
|
// Fast path for the standard library.
|
||||||
// In the common case we hopefully never have to scan the GOPATH, which can
|
// In the common case we hopefully never have to scan the GOPATH, which can
|
||||||
// be slow with moving disks.
|
// be slow with moving disks.
|
||||||
if pkg, ok := findImportStdlib(pkgName, symbols); ok {
|
if pkg, rename, ok := findImportStdlib(pkgName, symbols); ok {
|
||||||
return pkg, nil
|
return pkg, rename, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TODO(sameer): look at the import lines for other Go files in the
|
||||||
|
// local directory, since the user is likely to import the same packages
|
||||||
|
// in the current Go file. Return rename=true when the other Go files
|
||||||
|
// use a renamed package that's also used in the current file.
|
||||||
|
|
||||||
pkgIndexOnce.Do(loadPkgIndex)
|
pkgIndexOnce.Do(loadPkgIndex)
|
||||||
|
|
||||||
// Collect exports for packages with matching names.
|
// Collect exports for packages with matching names.
|
||||||
|
@ -311,7 +325,7 @@ func findImportGoPath(pkgName string, symbols map[string]bool) (string, error) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if len(pkgs) == 0 {
|
if len(pkgs) == 0 {
|
||||||
return "", nil
|
return "", false, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// If there are multiple candidate packages, the shortest one wins.
|
// If there are multiple candidate packages, the shortest one wins.
|
||||||
|
@ -323,7 +337,7 @@ func findImportGoPath(pkgName string, symbols map[string]bool) (string, error) {
|
||||||
shortest = importPath
|
shortest = importPath
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return shortest, nil
|
return shortest, false, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
type visitFn func(node ast.Node) ast.Visitor
|
type visitFn func(node ast.Node) ast.Visitor
|
||||||
|
@ -332,17 +346,17 @@ func (fn visitFn) Visit(node ast.Node) ast.Visitor {
|
||||||
return fn(node)
|
return fn(node)
|
||||||
}
|
}
|
||||||
|
|
||||||
func findImportStdlib(shortPkg string, symbols map[string]bool) (importPath string, ok bool) {
|
func findImportStdlib(shortPkg string, symbols map[string]bool) (importPath string, rename, ok bool) {
|
||||||
for symbol := range symbols {
|
for symbol := range symbols {
|
||||||
path := stdlib[shortPkg+"."+symbol]
|
path := stdlib[shortPkg+"."+symbol]
|
||||||
if path == "" {
|
if path == "" {
|
||||||
return "", false
|
return "", false, false
|
||||||
}
|
}
|
||||||
if importPath != "" && importPath != path {
|
if importPath != "" && importPath != path {
|
||||||
// Ambiguous. Symbols pointed to different things.
|
// Ambiguous. Symbols pointed to different things.
|
||||||
return "", false
|
return "", false, false
|
||||||
}
|
}
|
||||||
importPath = path
|
importPath = path
|
||||||
}
|
}
|
||||||
return importPath, importPath != ""
|
return importPath, false, importPath != ""
|
||||||
}
|
}
|
||||||
|
|
|
@ -505,6 +505,20 @@ var (
|
||||||
b = gu.a
|
b = gu.a
|
||||||
c = fmt.Printf
|
c = fmt.Printf
|
||||||
)
|
)
|
||||||
|
`,
|
||||||
|
},
|
||||||
|
|
||||||
|
{
|
||||||
|
name: "renamed package",
|
||||||
|
in: `package main
|
||||||
|
|
||||||
|
var _ = str.HasPrefix
|
||||||
|
`,
|
||||||
|
out: `package main
|
||||||
|
|
||||||
|
import str "strings"
|
||||||
|
|
||||||
|
var _ = str.HasPrefix
|
||||||
`,
|
`,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
@ -519,9 +533,10 @@ func TestFixImports(t *testing.T) {
|
||||||
"zip": "archive/zip",
|
"zip": "archive/zip",
|
||||||
"bytes": "bytes",
|
"bytes": "bytes",
|
||||||
"snappy": "code.google.com/p/snappy-go/snappy",
|
"snappy": "code.google.com/p/snappy-go/snappy",
|
||||||
|
"str": "strings",
|
||||||
}
|
}
|
||||||
findImport = func(pkgName string, symbols map[string]bool) (string, error) {
|
findImport = func(pkgName string, symbols map[string]bool) (string, bool, error) {
|
||||||
return simplePkgs[pkgName], nil
|
return simplePkgs[pkgName], pkgName == "str", nil
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
|
@ -577,20 +592,20 @@ type Buffer2 struct {}
|
||||||
build.Default.GOPATH = oldGOPATH
|
build.Default.GOPATH = oldGOPATH
|
||||||
}()
|
}()
|
||||||
|
|
||||||
got, err := findImportGoPath("bytes", map[string]bool{"Buffer2": true})
|
got, rename, err := findImportGoPath("bytes", map[string]bool{"Buffer2": true})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
if got != bytesPkgPath {
|
if got != bytesPkgPath || rename {
|
||||||
t.Errorf(`findImportGoPath("bytes", Buffer2 ...)=%q, want "%s"`, got, bytesPkgPath)
|
t.Errorf(`findImportGoPath("bytes", Buffer2 ...)=%q, %t, want "%s", false`, got, rename, bytesPkgPath)
|
||||||
}
|
}
|
||||||
|
|
||||||
got, err = findImportGoPath("bytes", map[string]bool{"Missing": true})
|
got, rename, err = findImportGoPath("bytes", map[string]bool{"Missing": true})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
if got != "" {
|
if got != "" || rename {
|
||||||
t.Errorf(`findImportGoPath("bytes", Missing ...)=%q, want ""`, got)
|
t.Errorf(`findImportGoPath("bytes", Missing ...)=%q, %t, want "", false`, got, rename)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -607,12 +622,12 @@ func TestFindImportStdlib(t *testing.T) {
|
||||||
{"ioutil", []string{"Discard"}, "io/ioutil"},
|
{"ioutil", []string{"Discard"}, "io/ioutil"},
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
got, ok := findImportStdlib(tt.pkg, strSet(tt.symbols))
|
got, rename, ok := findImportStdlib(tt.pkg, strSet(tt.symbols))
|
||||||
if (got != "") != ok {
|
if (got != "") != ok {
|
||||||
t.Error("findImportStdlib return value inconsistent")
|
t.Error("findImportStdlib return value inconsistent")
|
||||||
}
|
}
|
||||||
if got != tt.want {
|
if got != tt.want || rename {
|
||||||
t.Errorf("findImportStdlib(%q, %q) = %q; want %q", tt.pkg, tt.symbols, got, tt.want)
|
t.Errorf("findImportStdlib(%q, %q) = %q, %t; want %q, false", tt.pkg, tt.symbols, got, rename, tt.want)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue