diff --git a/imports/fix.go b/imports/fix.go index 75c3b755..085d8aa2 100644 --- a/imports/fix.go +++ b/imports/fix.go @@ -233,7 +233,7 @@ type pass struct { fset *token.FileSet // fset used to parse f and its siblings. f *ast.File // the file being fixed. srcDir string // the directory containing f. - useGoPackages bool // use go/packages to load package information. + fixEnv *fixEnv // the environment to use for go commands, etc. loadRealPackageNames bool // if true, load package names from disk rather than guessing them. otherFiles []*ast.File // sibling files. @@ -258,9 +258,9 @@ func (p *pass) loadPackageNames(imports []*importInfo) error { unknown = append(unknown, imp.importPath) } - if !p.useGoPackages { + if !p.fixEnv.shouldUseGoPackages() { for _, path := range unknown { - name := importPathToName(path, p.srcDir) + name := importPathToName(p.fixEnv, path, p.srcDir) if name == "" { continue } @@ -272,7 +272,7 @@ func (p *pass) loadPackageNames(imports []*importInfo) error { return nil } - cfg := newPackagesConfig(packages.LoadFiles) + cfg := p.fixEnv.newPackagesConfig(packages.LoadFiles) pkgs, err := packages.Load(cfg, unknown...) if err != nil { return err @@ -328,7 +328,9 @@ func (p *pass) load() bool { // f's imports by the identifier they introduce. imports := collectImports(p.f) if p.loadRealPackageNames { - p.loadPackageNames(append(imports, p.candidates...)) + if err := p.loadPackageNames(append(imports, p.candidates...)); err != nil { + panic(err) + } } for _, imp := range imports { p.existingImports[p.importIdentifier(imp)] = imp @@ -448,7 +450,7 @@ func (p *pass) addCandidate(imp *importInfo, pkg *packageInfo) { // easily be extended by adding a file with an init function. var fixImports = fixImportsDefault -func fixImportsDefault(fset *token.FileSet, f *ast.File, filename string) error { +func fixImportsDefault(fset *token.FileSet, f *ast.File, filename string, env *fixEnv) error { abs, err := filepath.Abs(filename) if err != nil { return err @@ -462,7 +464,7 @@ func fixImportsDefault(fset *token.FileSet, f *ast.File, filename string) error // derive package names from import paths, see if the file is already // complete. We can't add any imports yet, because we don't know // if missing references are actually package vars. - p := &pass{fset: fset, f: f, srcDir: srcDir} + p := &pass{fset: fset, f: f, srcDir: srcDir, fixEnv: env} if p.load() { return nil } @@ -471,7 +473,7 @@ func fixImportsDefault(fset *token.FileSet, f *ast.File, filename string) error // Second pass: add information from other files in the same package, // like their package vars and imports. - p = &pass{fset: fset, f: f, srcDir: srcDir} + p = &pass{fset: fset, f: f, srcDir: srcDir, fixEnv: env} p.otherFiles = otherFiles if p.load() { return nil @@ -484,13 +486,9 @@ func fixImportsDefault(fset *token.FileSet, f *ast.File, filename string) error return nil } - // The only things that use go/packages happen in the third pass, - // so we can delay calling go env until this point. - useGoPackages := shouldUseGoPackages() - // Third pass: get real package names where we had previously used // the naive algorithm. - p = &pass{fset: fset, f: f, srcDir: srcDir, useGoPackages: useGoPackages} + p = &pass{fset: fset, f: f, srcDir: srcDir, fixEnv: env} p.loadRealPackageNames = true p.otherFiles = otherFiles if p.load() { @@ -514,35 +512,66 @@ func fixImportsDefault(fset *token.FileSet, f *ast.File, filename string) error return nil } -// Values controlling the use of go/packages, for testing only. -var forceGoPackages, _ = strconv.ParseBool(os.Getenv("GOIMPORTSFORCEGOPACKAGES")) -var goPackagesDir string -var go111ModuleEnv string +// fixEnv contains environment variables and settings that affect the use of +// the go command, the go/build package, etc. +type fixEnv struct { + // If non-empty, these will be used instead of the + // process-wide values. + GOPATH, GOROOT, GO111MODULE string + WorkingDir string -func shouldUseGoPackages() bool { - if forceGoPackages { + // If true, use go/packages regardless of the environment. + ForceGoPackages bool + + ranGoEnv bool + gomod string +} + +func (e *fixEnv) env() []string { + env := os.Environ() + add := func(k, v string) { + if v != "" { + env = append(env, k+"="+v) + } + } + add("GOPATH", e.GOPATH) + add("GOROOT", e.GOROOT) + add("GO111MODULE", e.GO111MODULE) + return env +} + +func (e *fixEnv) shouldUseGoPackages() bool { + if e.ForceGoPackages { return true } - cmd := exec.Command("go", "env", "GOMOD") - cmd.Dir = goPackagesDir - out, err := cmd.Output() - if err != nil { - return false + if !e.ranGoEnv { + e.ranGoEnv = true + cmd := exec.Command("go", "env", "GOMOD") + cmd.Dir = e.WorkingDir + cmd.Env = e.env() + out, err := cmd.Output() + if err != nil { + return false + } + e.gomod = string(bytes.TrimSpace(out)) } - return len(bytes.TrimSpace(out)) > 0 + return e.gomod != "" } -func newPackagesConfig(mode packages.LoadMode) *packages.Config { - cfg := &packages.Config{ +func (e *fixEnv) newPackagesConfig(mode packages.LoadMode) *packages.Config { + return &packages.Config{ Mode: mode, - Dir: goPackagesDir, - Env: append(os.Environ(), "GOROOT="+build.Default.GOROOT, "GOPATH="+build.Default.GOPATH), + Dir: e.WorkingDir, + Env: e.env(), } - if go111ModuleEnv != "" { - cfg.Env = append(cfg.Env, "GO111MODULE="+go111ModuleEnv) - } - return cfg +} + +func (e *fixEnv) buildContext() *build.Context { + ctx := build.Default + ctx.GOROOT = e.GOROOT + ctx.GOPATH = e.GOPATH + return &ctx } func addStdlibCandidates(pass *pass, refs map[string]map[string]bool) { @@ -566,13 +595,13 @@ func addStdlibCandidates(pass *pass, refs map[string]map[string]bool) { } } -func scanGoPackages(refs map[string]map[string]bool) ([]*pkg, error) { +func scanGoPackages(env *fixEnv, refs map[string]map[string]bool) ([]*pkg, error) { var loadQueries []string for pkgName := range refs { loadQueries = append(loadQueries, "name="+pkgName) } sort.Strings(loadQueries) - cfg := newPackagesConfig(packages.LoadFiles) + cfg := env.newPackagesConfig(packages.LoadFiles) goPackages, err := packages.Load(cfg, loadQueries...) if err != nil { return nil, err @@ -593,14 +622,14 @@ var addExternalCandidates = addExternalCandidatesDefault func addExternalCandidatesDefault(pass *pass, refs map[string]map[string]bool, filename string) error { var dirScan []*pkg - if pass.useGoPackages { + if pass.fixEnv.shouldUseGoPackages() { var err error - dirScan, err = scanGoPackages(refs) + dirScan, err = scanGoPackages(pass.fixEnv, refs) if err != nil { return err } } else { - dirScan = scanGoDirs() + dirScan = scanGoDirs(pass.fixEnv) } // Search for imports matching potential package references. @@ -625,7 +654,7 @@ func addExternalCandidatesDefault(pass *pass, refs map[string]map[string]bool, f go func(pkgName string, symbols map[string]bool) { defer wg.Done() - found, err := findImport(ctx, dirScan, pkgName, symbols, filename) + found, err := findImport(ctx, pass.fixEnv, dirScan, pkgName, symbols, filename) if err != nil { firstErrOnce.Do(func() { @@ -678,13 +707,13 @@ func importPathToNameBasic(importPath, srcDir string) (packageName string) { // importPathToNameGoPath finds out the actual package name, as declared in its .go files. // If there's a problem, it returns "". -func importPathToName(importPath, srcDir string) (packageName string) { +func importPathToName(env *fixEnv, importPath, srcDir string) (packageName string) { // Fast path for standard library without going to disk. if _, ok := stdlib[importPath]; ok { return path.Base(importPath) // stdlib packages always match their paths. } - pkgName, err := importPathToNameGoPathParse(importPath, srcDir) + pkgName, err := importPathToNameGoPathParse(env, importPath, srcDir) if Debug { log.Printf("importPathToNameGoPathParse(%q, srcDir=%q) = %q, %v", importPath, srcDir, pkgName, err) } @@ -698,8 +727,8 @@ func importPathToName(importPath, srcDir string) (packageName string) { // the only thing desired is the package name. It uses build.FindOnly // to find the directory and then only parses one file in the package, // trusting that the files in the directory are consistent. -func importPathToNameGoPathParse(importPath, srcDir string) (packageName string, err error) { - buildPkg, err := build.Import(importPath, srcDir, build.FindOnly) +func importPathToNameGoPathParse(env *fixEnv, importPath, srcDir string) (packageName string, err error) { + buildPkg, err := env.buildContext().Import(importPath, srcDir, build.FindOnly) if err != nil { return "", err } @@ -798,7 +827,7 @@ func distance(basepath, targetpath string) int { } // scanGoDirs populates the dirScan map for GOPATH and GOROOT. -func scanGoDirs() []*pkg { +func scanGoDirs(env *fixEnv) []*pkg { dupCheck := make(map[string]bool) var result []*pkg @@ -818,7 +847,7 @@ func scanGoDirs() []*pkg { dir: dir, }) } - gopathwalk.Walk(gopathwalk.SrcDirsRoots(), add, gopathwalk.Options{Debug: Debug, ModulesEnabled: false}) + gopathwalk.Walk(gopathwalk.SrcDirsRoots(env.buildContext()), add, gopathwalk.Options{Debug: Debug, ModulesEnabled: false}) return result } @@ -837,7 +866,7 @@ func VendorlessPath(ipath string) string { // loadExports returns the set of exported symbols in the package at dir. // It returns nil on error or if the package name in dir does not match expectPackage. -func loadExports(ctx context.Context, expectPackage string, pkg *pkg) (map[string]bool, error) { +func loadExports(ctx context.Context, env *fixEnv, expectPackage string, pkg *pkg) (map[string]bool, error) { if Debug { log.Printf("loading exports in dir %s (seeking package %s)", pkg.dir, expectPackage) } @@ -871,7 +900,7 @@ func loadExports(ctx context.Context, expectPackage string, pkg *pkg) (map[strin if !strings.HasSuffix(name, ".go") || strings.HasSuffix(name, "_test.go") { continue } - match, err := build.Default.MatchFile(pkg.dir, fi.Name()) + match, err := env.buildContext().MatchFile(pkg.dir, fi.Name()) if err != nil || !match { continue } @@ -924,7 +953,7 @@ func loadExports(ctx context.Context, expectPackage string, pkg *pkg) (map[strin // findImport searches for a package with the given symbols. // If no package is found, findImport returns ("", false, nil) -func findImport(ctx context.Context, dirScan []*pkg, pkgName string, symbols map[string]bool, filename string) (*pkg, error) { +func findImport(ctx context.Context, env *fixEnv, dirScan []*pkg, pkgName string, symbols map[string]bool, filename string) (*pkg, error) { pkgDir, err := filepath.Abs(filename) if err != nil { return nil, err @@ -986,7 +1015,7 @@ func findImport(ctx context.Context, dirScan []*pkg, pkgName string, symbols map wg.Done() }() - exports, err := loadExports(ctx, pkgName, c.pkg) + exports, err := loadExports(ctx, env, pkgName, c.pkg) if err != nil { if Debug { log.Printf("loading exports in dir %s (seeking package %s): %v", c.pkg.dir, pkgName, err) diff --git a/imports/fix_test.go b/imports/fix_test.go index 1006e62d..dd9fe451 100644 --- a/imports/fix_test.go +++ b/imports/fix_test.go @@ -6,7 +6,6 @@ package imports import ( "fmt" - "go/build" "path/filepath" "runtime" "strings" @@ -1520,6 +1519,7 @@ func (c testConfig) test(t *testing.T, fn func(*goimportTest)) { t.Run(kind, func(t *testing.T) { t.Helper() + forceGoPackages := false var exporter packagestest.Exporter switch kind { case "GOPATH": @@ -1545,30 +1545,15 @@ func (c testConfig) test(t *testing.T, fn func(*goimportTest)) { env[k] = v } - goroot := env["GOROOT"] - gopath := env["GOPATH"] - - oldGOPATH := build.Default.GOPATH - oldGOROOT := build.Default.GOROOT - oldCompiler := build.Default.Compiler - build.Default.GOROOT = goroot - build.Default.GOPATH = gopath - build.Default.Compiler = "gc" - goPackagesDir = exported.Config.Dir - go111ModuleEnv = env["GO111MODULE"] - - defer func() { - build.Default.GOPATH = oldGOPATH - build.Default.GOROOT = oldGOROOT - build.Default.Compiler = oldCompiler - go111ModuleEnv = "" - goPackagesDir = "" - forceGoPackages = false - }() - it := &goimportTest{ - T: t, - gopath: gopath, + T: t, + fixEnv: &fixEnv{ + GOROOT: env["GOROOT"], + GOPATH: env["GOPATH"], + GO111MODULE: env["GO111MODULE"], + WorkingDir: exported.Config.Dir, + ForceGoPackages: forceGoPackages, + }, exported: exported, } fn(it) @@ -1586,7 +1571,7 @@ func (c testConfig) processTest(t *testing.T, module, file string, contents []by type goimportTest struct { *testing.T - gopath string + fixEnv *fixEnv exported *packagestest.Exported } @@ -1596,7 +1581,7 @@ func (t *goimportTest) process(module, file string, contents []byte, opts *Optio if f == "" { t.Fatalf("%v not found in exported files (typo in filename?)", file) } - buf, err := Process(f, contents, opts) + buf, err := process(f, contents, opts, t.fixEnv) if err != nil { t.Fatalf("Process() = %v", err) } @@ -1818,7 +1803,7 @@ func TestImportPathToNameGoPathParse(t *testing.T) { }, }, }.test(t, func(t *goimportTest) { - got, err := importPathToNameGoPathParse("example.net/pkg", filepath.Join(t.gopath, "src", "other.net")) + got, err := importPathToNameGoPathParse(t.fixEnv, "example.net/pkg", filepath.Join(t.fixEnv.GOPATH, "src", "other.net")) if err != nil { t.Fatal(err) } diff --git a/imports/imports.go b/imports/imports.go index 717a6f3a..07101cb8 100644 --- a/imports/imports.go +++ b/imports/imports.go @@ -13,6 +13,7 @@ import ( "bytes" "fmt" "go/ast" + "go/build" "go/format" "go/parser" "go/printer" @@ -45,6 +46,11 @@ type Options struct { // so it is important that filename be accurate. // To process data ``as if'' it were in filename, pass the data as a non-nil src. func Process(filename string, src []byte, opt *Options) ([]byte, error) { + env := &fixEnv{GOPATH: build.Default.GOPATH, GOROOT: build.Default.GOROOT} + return process(filename, src, opt, env) +} + +func process(filename string, src []byte, opt *Options, env *fixEnv) ([]byte, error) { if opt == nil { opt = &Options{Comments: true, TabIndent: true, TabWidth: 8} } @@ -63,7 +69,7 @@ func Process(filename string, src []byte, opt *Options) ([]byte, error) { } if !opt.FormatOnly { - if err := fixImports(fileSet, file, filename); err != nil { + if err := fixImports(fileSet, file, filename, env); err != nil { return nil, err } } diff --git a/internal/gopathwalk/walk.go b/internal/gopathwalk/walk.go index a561f9f4..488088b5 100644 --- a/internal/gopathwalk/walk.go +++ b/internal/gopathwalk/walk.go @@ -44,10 +44,10 @@ type Root struct { } // SrcDirsRoots returns the roots from build.Default.SrcDirs(). Not modules-compatible. -func SrcDirsRoots() []Root { +func SrcDirsRoots(ctx *build.Context) []Root { var roots []Root - roots = append(roots, Root{filepath.Join(build.Default.GOROOT, "src"), RootGOROOT}) - for _, p := range filepath.SplitList(build.Default.GOPATH) { + roots = append(roots, Root{filepath.Join(ctx.GOROOT, "src"), RootGOROOT}) + for _, p := range filepath.SplitList(ctx.GOPATH) { roots = append(roots, Root{filepath.Join(p, "src"), RootGOPATH}) } return roots