imports: prefer paths imported by sibling files.

Adds an Imports field to packageInfo with the imports used by sibling
files, and uses it preferentially if it matches a missing import.

Example: if foo/foo.go imports "local/log", it's a reasonable assumption
that foo/bar.go will also want "local/log" instead of "log".

Change-Id: Ifb504ed5e00ff18459f19d8598cc2c94099ae563
Reviewed-on: https://go-review.googlesource.com/43454
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
This commit is contained in:
Brad Jones 2017-05-12 13:54:58 -07:00 committed by Brad Fitzpatrick
parent e493388965
commit e011c1062a
2 changed files with 85 additions and 2 deletions

View File

@ -68,9 +68,16 @@ func importGroup(importPath string) int {
return 0
}
// importInfo is a summary of information about one import.
type importInfo struct {
Path string // full import path (e.g. "crypto/rand")
Alias string // import alias, if present (e.g. "crand")
}
// packageInfo is a summary of features found in a package.
type packageInfo struct {
Globals map[string]bool // symbol => true
Imports map[string]importInfo // pkg base name or alias => info
}
// dirPackageInfo exposes the dirPackageInfoFile function so that it can be overridden.
@ -94,7 +101,7 @@ func dirPackageInfoFile(pkgName, srcDir, filename string) (*packageInfo, error)
return nil, err
}
info := &packageInfo{Globals: make(map[string]bool)}
info := &packageInfo{Globals: make(map[string]bool), Imports: make(map[string]importInfo)}
for _, fi := range packageFileInfos {
if fi.Name() == fileBase || !strings.HasSuffix(fi.Name(), ".go") {
continue
@ -123,6 +130,16 @@ func dirPackageInfoFile(pkgName, srcDir, filename string) (*packageInfo, error)
info.Globals[valueSpec.Names[0].Name] = true
}
}
for _, imp := range root.Imports {
impInfo := importInfo{Path: strings.Trim(imp.Path.Value, `"`)}
name := path.Base(impInfo.Path)
if imp.Name != nil {
name = strings.Trim(imp.Name.Name, `"`)
impInfo.Alias = name
}
info.Imports[name] = impInfo
}
}
return info, nil
}
@ -217,6 +234,16 @@ func fixImports(fset *token.FileSet, f *ast.File, filename string) (added []stri
}
}
// Fast path, all references already imported.
if len(refs) == 0 {
return nil, nil
}
// Can assume this will be necessary in all cases now.
if !loadedPackageInfo {
packageInfo, _ = dirPackageInfo(f.Name.Name, srcDir, filename)
}
// Search for imports matching potential package references.
searches := 0
type result struct {
@ -227,6 +254,11 @@ func fixImports(fset *token.FileSet, f *ast.File, filename string) (added []stri
results := make(chan result)
for pkgName, symbols := range refs {
go func(pkgName string, symbols map[string]bool) {
sibling := packageInfo.Imports[pkgName]
if sibling.Path != "" {
results <- result{ipath: sibling.Path, name: sibling.Alias}
return
}
ipath, rename, err := findImport(pkgName, symbols, filename)
r := result{ipath: ipath, err: err}
if rename {

View File

@ -1536,6 +1536,57 @@ func TestGlobalImports(t *testing.T) {
})
}
// Tests that sibling files - other files in the same package - can provide an
// import that may not be the default one otherwise.
func TestSiblingImports(t *testing.T) {
// provide is the sibling file that provides the desired import.
const provide = `package siblingimporttest
import "local/log"
func LogSomething() {
log.Print("Something")
}
`
// need is the file being tested that needs the import.
const need = `package siblingimporttest
func LogSomethingElse() {
log.Print("Something else")
}
`
// want is the expected result file
const want = `package siblingimporttest
import "local/log"
func LogSomethingElse() {
log.Print("Something else")
}
`
const pkg = "siblingimporttest"
const siblingFile = pkg + "/needs_import.go"
testConfig{
gopathFiles: map[string]string{
siblingFile: need,
pkg + "/provides_import.go": provide,
},
}.test(t, func(t *goimportTest) {
buf, err := Process(
t.gopath+"/src/"+siblingFile, []byte(need), nil)
if err != nil {
t.Fatal(err)
}
if string(buf) != want {
t.Errorf("wrong output.\ngot:\n%q\nwant:\n%q\n", buf, want)
}
})
}
func strSet(ss []string) map[string]bool {
m := make(map[string]bool)
for _, s := range ss {