From e011c1062a41d0e2614279f1dc8cf7f1bf6aaac3 Mon Sep 17 00:00:00 2001 From: Brad Jones Date: Fri, 12 May 2017 13:54:58 -0700 Subject: [PATCH] imports: prefer paths imported by sibling files. Adds an Imports field to packageInfo with the imports used by sibling files, and uses it preferentially if it matches a missing import. Example: if foo/foo.go imports "local/log", it's a reasonable assumption that foo/bar.go will also want "local/log" instead of "log". Change-Id: Ifb504ed5e00ff18459f19d8598cc2c94099ae563 Reviewed-on: https://go-review.googlesource.com/43454 Run-TryBot: Brad Fitzpatrick Reviewed-by: Brad Fitzpatrick --- imports/fix.go | 36 ++++++++++++++++++++++++++++++-- imports/fix_test.go | 51 +++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 85 insertions(+), 2 deletions(-) diff --git a/imports/fix.go b/imports/fix.go index 61a5f062..e772b6df 100644 --- a/imports/fix.go +++ b/imports/fix.go @@ -68,9 +68,16 @@ func importGroup(importPath string) int { return 0 } +// importInfo is a summary of information about one import. +type importInfo struct { + Path string // full import path (e.g. "crypto/rand") + Alias string // import alias, if present (e.g. "crand") +} + // packageInfo is a summary of features found in a package. type packageInfo struct { - Globals map[string]bool // symbol => true + Globals map[string]bool // symbol => true + Imports map[string]importInfo // pkg base name or alias => info } // dirPackageInfo exposes the dirPackageInfoFile function so that it can be overridden. @@ -94,7 +101,7 @@ func dirPackageInfoFile(pkgName, srcDir, filename string) (*packageInfo, error) return nil, err } - info := &packageInfo{Globals: make(map[string]bool)} + info := &packageInfo{Globals: make(map[string]bool), Imports: make(map[string]importInfo)} for _, fi := range packageFileInfos { if fi.Name() == fileBase || !strings.HasSuffix(fi.Name(), ".go") { continue @@ -123,6 +130,16 @@ func dirPackageInfoFile(pkgName, srcDir, filename string) (*packageInfo, error) info.Globals[valueSpec.Names[0].Name] = true } } + + for _, imp := range root.Imports { + impInfo := importInfo{Path: strings.Trim(imp.Path.Value, `"`)} + name := path.Base(impInfo.Path) + if imp.Name != nil { + name = strings.Trim(imp.Name.Name, `"`) + impInfo.Alias = name + } + info.Imports[name] = impInfo + } } return info, nil } @@ -217,6 +234,16 @@ func fixImports(fset *token.FileSet, f *ast.File, filename string) (added []stri } } + // Fast path, all references already imported. + if len(refs) == 0 { + return nil, nil + } + + // Can assume this will be necessary in all cases now. + if !loadedPackageInfo { + packageInfo, _ = dirPackageInfo(f.Name.Name, srcDir, filename) + } + // Search for imports matching potential package references. searches := 0 type result struct { @@ -227,6 +254,11 @@ func fixImports(fset *token.FileSet, f *ast.File, filename string) (added []stri results := make(chan result) for pkgName, symbols := range refs { go func(pkgName string, symbols map[string]bool) { + sibling := packageInfo.Imports[pkgName] + if sibling.Path != "" { + results <- result{ipath: sibling.Path, name: sibling.Alias} + return + } ipath, rename, err := findImport(pkgName, symbols, filename) r := result{ipath: ipath, err: err} if rename { diff --git a/imports/fix_test.go b/imports/fix_test.go index 048b9c3a..2026d5c9 100644 --- a/imports/fix_test.go +++ b/imports/fix_test.go @@ -1536,6 +1536,57 @@ func TestGlobalImports(t *testing.T) { }) } +// Tests that sibling files - other files in the same package - can provide an +// import that may not be the default one otherwise. +func TestSiblingImports(t *testing.T) { + + // provide is the sibling file that provides the desired import. + const provide = `package siblingimporttest + +import "local/log" + +func LogSomething() { + log.Print("Something") +} +` + + // need is the file being tested that needs the import. + const need = `package siblingimporttest + +func LogSomethingElse() { + log.Print("Something else") +} +` + + // want is the expected result file + const want = `package siblingimporttest + +import "local/log" + +func LogSomethingElse() { + log.Print("Something else") +} +` + + const pkg = "siblingimporttest" + const siblingFile = pkg + "/needs_import.go" + testConfig{ + gopathFiles: map[string]string{ + siblingFile: need, + pkg + "/provides_import.go": provide, + }, + }.test(t, func(t *goimportTest) { + buf, err := Process( + t.gopath+"/src/"+siblingFile, []byte(need), nil) + if err != nil { + t.Fatal(err) + } + if string(buf) != want { + t.Errorf("wrong output.\ngot:\n%q\nwant:\n%q\n", buf, want) + } + }) +} + func strSet(ss []string) map[string]bool { m := make(map[string]bool) for _, s := range ss {