diff --git a/internal/imports/fix.go b/internal/imports/fix.go index 76a79e16..2c58d18b 100644 --- a/internal/imports/fix.go +++ b/internal/imports/fix.go @@ -67,6 +67,19 @@ func importGroup(env *ProcessEnv, importPath string) int { return 0 } +type importFixType int + +const ( + addImport importFixType = iota + deleteImport + setImportName +) + +type importFix struct { + info importInfo + fixType importFixType +} + // An importInfo represents a single import statement. type importInfo struct { importPath string // import path, e.g. "crypto/rand". @@ -290,7 +303,7 @@ func (p *pass) importIdentifier(imp *importInfo) string { // load reads in everything necessary to run a pass, and reports whether the // file already has all the imports it needs. It fills in p.missingRefs with the // file's missing symbols, if any, or removes unused imports if not. -func (p *pass) load() bool { +func (p *pass) load() ([]*importFix, bool) { p.knownPackages = map[string]*packageInfo{} p.missingRefs = references{} p.existingImports = map[string]*importInfo{} @@ -320,7 +333,7 @@ func (p *pass) load() bool { if p.env.Debug { p.env.Logf("loading package names: %v", err) } - return false + return nil, false } } for _, imp := range imports { @@ -339,16 +352,16 @@ func (p *pass) load() bool { } } if len(p.missingRefs) != 0 { - return false + return nil, false } return p.fix() } // fix attempts to satisfy missing imports using p.candidates. If it finds -// everything, or if p.lastTry is true, it adds the imports it found, -// removes anything unused, and returns true. -func (p *pass) fix() bool { +// everything, or if p.lastTry is true, it updates fixes to add the imports it found, +// delete anything unused, and update import names, and returns true. +func (p *pass) fix() ([]*importFix, bool) { // Find missing imports. var selected []*importInfo for left, rights := range p.missingRefs { @@ -358,10 +371,11 @@ func (p *pass) fix() bool { } if !p.lastTry && len(selected) != len(p.missingRefs) { - return false + return nil, false } // Found everything, or giving up. Add the new imports and remove any unused. + var fixes []*importFix for _, imp := range p.existingImports { // We deliberately ignore globals here, because we can't be sure // they're in the same package. People do things like put multiple @@ -369,27 +383,77 @@ func (p *pass) fix() bool { // remove imports if they happen to have the same name as a var in // a different package. if _, ok := p.allRefs[p.importIdentifier(imp)]; !ok { - astutil.DeleteNamedImport(p.fset, p.f, imp.name, imp.importPath) + fixes = append(fixes, &importFix{ + info: *imp, + fixType: deleteImport, + }) + continue + } + + // An existing import may need to update its import name to be correct. + if name := p.importSpecName(imp); name != imp.name { + fixes = append(fixes, &importFix{ + info: importInfo{ + name: name, + importPath: imp.importPath, + }, + fixType: setImportName, + }) } } for _, imp := range selected { - astutil.AddNamedImport(p.fset, p.f, imp.name, imp.importPath) + fixes = append(fixes, &importFix{ + info: importInfo{ + name: p.importSpecName(imp), + importPath: imp.importPath, + }, + fixType: addImport, + }) } - if p.loadRealPackageNames { - for _, imp := range p.f.Imports { - if imp.Name != nil { - continue - } - path := strings.Trim(imp.Path.Value, `""`) - ident := p.importIdentifier(&importInfo{importPath: path}) - if ident != importPathToAssumedName(path) { - imp.Name = &ast.Ident{Name: ident, NamePos: imp.Pos()} + return fixes, true +} + +// importSpecName gets the import name of imp in the import spec. +// +// When the import identifier matches the assumed import name, the import name does +// not appear in the import spec. +func (p *pass) importSpecName(imp *importInfo) string { + // If we did not load the real package names, or the name is already set, + // we just return the existing name. + if !p.loadRealPackageNames || imp.name != "" { + return imp.name + } + + ident := p.importIdentifier(imp) + if ident == importPathToAssumedName(imp.importPath) { + return "" // ident not needed since the assumed and real names are the same. + } + return ident +} + +// apply will perform the fixes on f in order. +func apply(fset *token.FileSet, f *ast.File, fixes []*importFix) bool { + for _, fix := range fixes { + switch fix.fixType { + case deleteImport: + astutil.DeleteNamedImport(fset, f, fix.info.name, fix.info.importPath) + case addImport: + astutil.AddNamedImport(fset, f, fix.info.name, fix.info.importPath) + case setImportName: + // Find the matching import path and change the name. + for _, spec := range f.Imports { + path := strings.Trim(spec.Path.Value, `""`) + if path == fix.info.importPath { + spec.Name = &ast.Ident{ + Name: fix.info.name, + NamePos: spec.Pos(), + } + } } } } - return true } @@ -442,10 +506,21 @@ func (p *pass) addCandidate(imp *importInfo, pkg *packageInfo) { var fixImports = fixImportsDefault func fixImportsDefault(fset *token.FileSet, f *ast.File, filename string, env *ProcessEnv) error { - abs, err := filepath.Abs(filename) + fixes, err := getFixes(fset, f, filename, env) if err != nil { return err } + apply(fset, f, fixes) + return err +} + +// getFixes gets the getFixes that need to be made to f in order to fix the imports. +// It does not modify the ast. +func getFixes(fset *token.FileSet, f *ast.File, filename string, env *ProcessEnv) ([]*importFix, error) { + abs, err := filepath.Abs(filename) + if err != nil { + return nil, err + } srcDir := filepath.Dir(abs) if env.Debug { env.Logf("fixImports(filename=%q), abs=%q, srcDir=%q ...", filename, abs, srcDir) @@ -456,8 +531,8 @@ func fixImportsDefault(fset *token.FileSet, f *ast.File, filename string, env *P // complete. We can't add any imports yet, because we don't know // if missing references are actually package vars. p := &pass{fset: fset, f: f, srcDir: srcDir} - if p.load() { - return nil + if fixes, done := p.load(); done { + return fixes, nil } otherFiles := parseOtherFiles(fset, srcDir, filename) @@ -465,15 +540,15 @@ func fixImportsDefault(fset *token.FileSet, f *ast.File, filename string, env *P // Second pass: add information from other files in the same package, // like their package vars and imports. p.otherFiles = otherFiles - if p.load() { - return nil + if fixes, done := p.load(); done { + return fixes, nil } // Now we can try adding imports from the stdlib. p.assumeSiblingImportsValid() addStdlibCandidates(p, p.missingRefs) - if p.fix() { - return nil + if fixes, done := p.fix(); done { + return fixes, nil } // Third pass: get real package names where we had previously used @@ -482,25 +557,25 @@ func fixImportsDefault(fset *token.FileSet, f *ast.File, filename string, env *P p = &pass{fset: fset, f: f, srcDir: srcDir, env: env} p.loadRealPackageNames = true p.otherFiles = otherFiles - if p.load() { - return nil + if fixes, done := p.load(); done { + return fixes, nil } addStdlibCandidates(p, p.missingRefs) p.assumeSiblingImportsValid() - if p.fix() { - return nil + if fixes, done := p.fix(); done { + return fixes, nil } // Go look for candidates in $GOPATH, etc. We don't necessarily load // the real exports of sibling imports, so keep assuming their contents. if err := addExternalCandidates(p, p.missingRefs, filename); err != nil { - return err + return nil, err } p.lastTry = true - p.fix() - return nil + fixes, _ := p.fix() + return fixes, nil } // ProcessEnv contains environment variables and settings that affect the use of