diff --git a/imports/fix.go b/imports/fix.go index 1b021374..d13836ce 100644 --- a/imports/fix.go +++ b/imports/fix.go @@ -54,6 +54,12 @@ func fixImports(fset *token.FileSet, f *ast.File, filename string) (added []stri // decls are the current package imports. key is base package or renamed package. decls := make(map[string]*ast.ImportSpec) + abs, err := filepath.Abs(filename) + if err != nil { + return nil, err + } + srcDir := path.Dir(abs) + // collect potential uses of packages. var visitor visitFn visitor = visitFn(func(node ast.Node) ast.Visitor { @@ -65,7 +71,7 @@ func fixImports(fset *token.FileSet, f *ast.File, filename string) (added []stri if v.Name != nil { decls[v.Name.Name] = v } else { - local := importPathToName(strings.Trim(v.Path.Value, `\"`)) + local := importPathToName(strings.Trim(v.Path.Value, `\"`), srcDir) decls[local] = v } case *ast.SelectorExpr: @@ -152,17 +158,17 @@ func fixImports(fset *token.FileSet, f *ast.File, filename string) (added []stri var importPathToName = importPathToNameGoPath // importPathToNameBasic assumes the package name is the base of import path. -func importPathToNameBasic(importPath string) (packageName string) { +func importPathToNameBasic(importPath, srcDir string) (packageName string) { return path.Base(importPath) } // importPathToNameGoPath finds out the actual package name, as declared in its .go files. // If there's a problem, it falls back to using importPathToNameBasic. -func importPathToNameGoPath(importPath string) (packageName string) { - if buildPkg, err := build.Import(importPath, "", 0); err == nil { +func importPathToNameGoPath(importPath, srcDir string) (packageName string) { + if buildPkg, err := build.Import(importPath, srcDir, 0); err == nil { return buildPkg.Name } else { - return importPathToNameBasic(importPath) + return importPathToNameBasic(importPath, srcDir) } } @@ -260,7 +266,7 @@ func loadPkg(wg *sync.WaitGroup, root, pkgrelpath string) { } } if hasGo { - shortName := importPathToName(importpath) + shortName := importPathToName(importpath, "") pkgIndex.Lock() pkgIndex.m[shortName] = append(pkgIndex.m[shortName], pkg{ importpath: importpath, diff --git a/imports/fix_test.go b/imports/fix_test.go index 3506dbd1..9c89ef67 100644 --- a/imports/fix_test.go +++ b/imports/fix_test.go @@ -823,6 +823,60 @@ func TestFixImports(t *testing.T) { } } +// Test for correctly identifying the name of a vendored package when it +// differs from its directory name. In this test, the import line +// "mypkg.com/mypkg.v1" would be removed if goimports wasn't able to detect +// that the package name is "mypkg". +func TestFixImportsVendorPackage(t *testing.T) { + // Skip this test on go versions with no vendor support. + if _, err := os.Stat(filepath.Join(runtime.GOROOT(), "src/vendor")); err != nil { + t.Skip(err) + } + + newGoPath, err := ioutil.TempDir("", "vendortest") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(newGoPath) + + vendoredPath := newGoPath + "/src/mypkg.com/outpkg/vendor/mypkg.com/mypkg.v1" + if err := os.MkdirAll(vendoredPath, 0755); err != nil { + t.Fatal(err) + } + + pkgIndexOnce = &sync.Once{} + oldGOPATH := build.Default.GOPATH + build.Default.GOPATH = newGoPath + defer func() { + build.Default.GOPATH = oldGOPATH + }() + + if err := ioutil.WriteFile(vendoredPath+"/f.go", []byte("package mypkg\nvar Foo = 123\n"), 0666); err != nil { + t.Fatal(err) + } + + input := `package p + +import ( + "fmt" + + "mypkg.com/mypkg.v1" +) + +var ( + _ = fmt.Print + _ = mypkg.Foo +) +` + buf, err := Process(newGoPath+"/src/mypkg.com/outpkg/toformat.go", []byte(input), &Options{}) + if err != nil { + t.Fatal(err) + } + if got := string(buf); got != input { + t.Fatalf("results differ\nGOT:\n%s\nWANT:\n%s\n", got, input) + } +} + func TestFindImportGoPath(t *testing.T) { goroot, err := ioutil.TempDir("", "goimports-") if err != nil {