diff --git a/imports/fix.go b/imports/fix.go index a5f62cd4..c6535f9a 100644 --- a/imports/fix.go +++ b/imports/fix.go @@ -93,6 +93,7 @@ func fixImports(fset *token.FileSet, f *ast.File) (added []string, err error) { searches := 0 type result struct { ipath string + name string err error } results := make(chan result) @@ -101,8 +102,12 @@ func fixImports(fset *token.FileSet, f *ast.File) (added []string, err error) { continue // skip over packages already imported } go func(pkgName string, symbols map[string]bool) { - ipath, err := findImport(pkgName, symbols) - results <- result{ipath, err} + ipath, rename, err := findImport(pkgName, symbols) + r := result{ipath: ipath, err: err} + if rename { + r.name = pkgName + } + results <- r }(pkgName, symbols) searches++ } @@ -112,7 +117,11 @@ func fixImports(fset *token.FileSet, f *ast.File) (added []string, err error) { return nil, result.err } if result.ipath != "" { - astutil.AddImport(fset, f, result.ipath) + if result.name != "" { + astutil.AddNamedImport(fset, f, result.name, result.ipath) + } else { + astutil.AddImport(fset, f, result.ipath) + } added = append(added, result.ipath) } } @@ -270,14 +279,19 @@ func loadExportsGoPath(dir string) map[string]bool { // extended by adding a file with an init function. var findImport = findImportGoPath -func findImportGoPath(pkgName string, symbols map[string]bool) (string, error) { +func findImportGoPath(pkgName string, symbols map[string]bool) (string, bool, error) { // Fast path for the standard library. // In the common case we hopefully never have to scan the GOPATH, which can // be slow with moving disks. - if pkg, ok := findImportStdlib(pkgName, symbols); ok { - return pkg, nil + if pkg, rename, ok := findImportStdlib(pkgName, symbols); ok { + return pkg, rename, nil } + // TODO(sameer): look at the import lines for other Go files in the + // local directory, since the user is likely to import the same packages + // in the current Go file. Return rename=true when the other Go files + // use a renamed package that's also used in the current file. + pkgIndexOnce.Do(loadPkgIndex) // Collect exports for packages with matching names. @@ -311,7 +325,7 @@ func findImportGoPath(pkgName string, symbols map[string]bool) (string, error) { } } if len(pkgs) == 0 { - return "", nil + return "", false, nil } // If there are multiple candidate packages, the shortest one wins. @@ -323,7 +337,7 @@ func findImportGoPath(pkgName string, symbols map[string]bool) (string, error) { shortest = importPath } } - return shortest, nil + return shortest, false, nil } type visitFn func(node ast.Node) ast.Visitor @@ -332,17 +346,17 @@ func (fn visitFn) Visit(node ast.Node) ast.Visitor { return fn(node) } -func findImportStdlib(shortPkg string, symbols map[string]bool) (importPath string, ok bool) { +func findImportStdlib(shortPkg string, symbols map[string]bool) (importPath string, rename, ok bool) { for symbol := range symbols { path := stdlib[shortPkg+"."+symbol] if path == "" { - return "", false + return "", false, false } if importPath != "" && importPath != path { // Ambiguous. Symbols pointed to different things. - return "", false + return "", false, false } importPath = path } - return importPath, importPath != "" + return importPath, false, importPath != "" } diff --git a/imports/fix_test.go b/imports/fix_test.go index 9b7ddde7..f96f8535 100644 --- a/imports/fix_test.go +++ b/imports/fix_test.go @@ -505,6 +505,20 @@ var ( b = gu.a c = fmt.Printf ) +`, + }, + + { + name: "renamed package", + in: `package main + +var _ = str.HasPrefix +`, + out: `package main + +import str "strings" + +var _ = str.HasPrefix `, }, } @@ -519,9 +533,10 @@ func TestFixImports(t *testing.T) { "zip": "archive/zip", "bytes": "bytes", "snappy": "code.google.com/p/snappy-go/snappy", + "str": "strings", } - findImport = func(pkgName string, symbols map[string]bool) (string, error) { - return simplePkgs[pkgName], nil + findImport = func(pkgName string, symbols map[string]bool) (string, bool, error) { + return simplePkgs[pkgName], pkgName == "str", nil } for _, tt := range tests { @@ -577,20 +592,20 @@ type Buffer2 struct {} build.Default.GOPATH = oldGOPATH }() - got, err := findImportGoPath("bytes", map[string]bool{"Buffer2": true}) + got, rename, err := findImportGoPath("bytes", map[string]bool{"Buffer2": true}) if err != nil { t.Fatal(err) } - if got != bytesPkgPath { - t.Errorf(`findImportGoPath("bytes", Buffer2 ...)=%q, want "%s"`, got, bytesPkgPath) + if got != bytesPkgPath || rename { + t.Errorf(`findImportGoPath("bytes", Buffer2 ...)=%q, %t, want "%s", false`, got, rename, bytesPkgPath) } - got, err = findImportGoPath("bytes", map[string]bool{"Missing": true}) + got, rename, err = findImportGoPath("bytes", map[string]bool{"Missing": true}) if err != nil { t.Fatal(err) } - if got != "" { - t.Errorf(`findImportGoPath("bytes", Missing ...)=%q, want ""`, got) + if got != "" || rename { + t.Errorf(`findImportGoPath("bytes", Missing ...)=%q, %t, want "", false`, got, rename) } } @@ -607,12 +622,12 @@ func TestFindImportStdlib(t *testing.T) { {"ioutil", []string{"Discard"}, "io/ioutil"}, } for _, tt := range tests { - got, ok := findImportStdlib(tt.pkg, strSet(tt.symbols)) + got, rename, ok := findImportStdlib(tt.pkg, strSet(tt.symbols)) if (got != "") != ok { t.Error("findImportStdlib return value inconsistent") } - if got != tt.want { - t.Errorf("findImportStdlib(%q, %q) = %q; want %q", tt.pkg, tt.symbols, got, tt.want) + if got != tt.want || rename { + t.Errorf("findImportStdlib(%q, %q) = %q, %t; want %q, false", tt.pkg, tt.symbols, got, rename, tt.want) } } }