diff --git a/imports/fix.go b/imports/fix.go index e99d4429..a241c297 100644 --- a/imports/fix.go +++ b/imports/fix.go @@ -68,6 +68,62 @@ func importGroup(importPath string) int { return 0 } +// packageInfo is a summary of features found in a package. +type packageInfo struct { + Globals map[string]bool // symbol => true +} + +// dirPackageInfo gets information from other files in the package. +func dirPackageInfo(srcDir, filename string) (*packageInfo, error) { + considerTests := strings.HasSuffix(filename, "_test.go") + + // Handle file from stdin + if _, err := os.Stat(filename); err != nil { + if os.IsNotExist(err) { + return &packageInfo{}, nil + } + return nil, err + } + + fileBase := filepath.Base(filename) + packageFileInfos, err := ioutil.ReadDir(srcDir) + if err != nil { + return nil, err + } + + info := &packageInfo{Globals: make(map[string]bool)} + for _, fi := range packageFileInfos { + if fi.Name() == fileBase || !strings.HasSuffix(fi.Name(), ".go") { + continue + } + if !considerTests && strings.HasSuffix(fi.Name(), "_test.go") { + continue + } + + fileSet := token.NewFileSet() + root, err := parser.ParseFile(fileSet, filepath.Join(srcDir, fi.Name()), nil, 0) + if err != nil { + continue + } + + for _, decl := range root.Decls { + genDecl, ok := decl.(*ast.GenDecl) + if !ok { + continue + } + + for _, spec := range genDecl.Specs { + valueSpec, ok := spec.(*ast.ValueSpec) + if !ok { + continue + } + info.Globals[valueSpec.Names[0].Name] = true + } + } + } + return info, nil +} + func fixImports(fset *token.FileSet, f *ast.File, filename string) (added []string, err error) { // refs are a set of possible package references currently unsatisfied by imports. // first key: either base package (e.g. "fmt") or renamed package @@ -86,6 +142,9 @@ func fixImports(fset *token.FileSet, f *ast.File, filename string) (added []stri log.Printf("fixImports(filename=%q), abs=%q, srcDir=%q ...", filename, abs, srcDir) } + var packageInfo *packageInfo + var loadedPackageInfo bool + // collect potential uses of packages. var visitor visitFn visitor = visitFn(func(node ast.Node) ast.Visitor { @@ -117,7 +176,11 @@ func fixImports(fset *token.FileSet, f *ast.File, filename string) (added []stri if refs[pkgName] == nil { refs[pkgName] = make(map[string]bool) } - if decls[pkgName] == nil { + if !loadedPackageInfo { + loadedPackageInfo = true + packageInfo, _ = dirPackageInfo(srcDir, filename) + } + if decls[pkgName] == nil && (packageInfo == nil || !packageInfo.Globals[pkgName]) { refs[pkgName][v.Sel.Name] = true } } diff --git a/imports/fix_test.go b/imports/fix_test.go index 887e1a07..f75a1d63 100644 --- a/imports/fix_test.go +++ b/imports/fix_test.go @@ -1428,6 +1428,47 @@ func TestGoRootPrefixOfGoPath(t *testing.T) { } +const testGlobalImportsUsesGlobal = `package globalimporttest + +func doSomething() { + t := time.Now() +} +` + +const testGlobalImportsGlobalDecl = `package globalimporttest + +type Time struct{} + +func (t Time) Now() Time { + return Time{} +} + +var time Time +` + +// Tests that package global variables with the same name and function name as +// a function in a separate package do not result in an import which masks +// the global variable +func TestGlobalImports(t *testing.T) { + const pkg = "globalimporttest" + const usesGlobalFile = pkg + "/uses_global.go" + testConfig{ + gopathFiles: map[string]string{ + usesGlobalFile: testGlobalImportsUsesGlobal, + pkg + "/global.go": testGlobalImportsGlobalDecl, + }, + }.test(t, func(t *goimportTest) { + buf, err := Process( + t.gopath+"/src/"+usesGlobalFile, []byte(testGlobalImportsUsesGlobal), nil) + if err != nil { + t.Fatal(err) + } + if string(buf) != testGlobalImportsUsesGlobal { + t.Errorf("wrong output.\ngot:\n%q\nwant:\n%q\n", buf, testGlobalImportsUsesGlobal) + } + }) +} + func strSet(ss []string) map[string]bool { m := make(map[string]bool) for _, s := range ss {