diff --git a/go/ast/astutil/imports.go b/go/ast/astutil/imports.go index 04ad6795..3e4b1953 100644 --- a/go/ast/astutil/imports.go +++ b/go/ast/astutil/imports.go @@ -14,26 +14,26 @@ import ( ) // AddImport adds the import path to the file f, if absent. -func AddImport(fset *token.FileSet, f *ast.File, ipath string) (added bool) { - return AddNamedImport(fset, f, "", ipath) +func AddImport(fset *token.FileSet, f *ast.File, path string) (added bool) { + return AddNamedImport(fset, f, "", path) } -// AddNamedImport adds the import path to the file f, if absent. +// AddNamedImport adds the import with the given name and path to the file f, if absent. // If name is not empty, it is used to rename the import. // // For example, calling // AddNamedImport(fset, f, "pathpkg", "path") // adds // import pathpkg "path" -func AddNamedImport(fset *token.FileSet, f *ast.File, name, ipath string) (added bool) { - if imports(f, ipath) { +func AddNamedImport(fset *token.FileSet, f *ast.File, name, path string) (added bool) { + if imports(f, name, path) { return false } newImport := &ast.ImportSpec{ Path: &ast.BasicLit{ Kind: token.STRING, - Value: strconv.Quote(ipath), + Value: strconv.Quote(path), }, } if name != "" { @@ -43,14 +43,14 @@ func AddNamedImport(fset *token.FileSet, f *ast.File, name, ipath string) (added // Find an import decl to add to. // The goal is to find an existing import // whose import path has the longest shared - // prefix with ipath. + // prefix with path. var ( bestMatch = -1 // length of longest shared prefix lastImport = -1 // index in f.Decls of the file's final import decl impDecl *ast.GenDecl // import decl containing the best match impIndex = -1 // spec index in impDecl containing the best match - isThirdPartyPath = isThirdParty(ipath) + isThirdPartyPath = isThirdParty(path) ) for i, decl := range f.Decls { gen, ok := decl.(*ast.GenDecl) @@ -81,7 +81,7 @@ func AddNamedImport(fset *token.FileSet, f *ast.File, name, ipath string) (added for j, spec := range gen.Specs { impspec := spec.(*ast.ImportSpec) p := importPath(impspec) - n := matchLen(p, ipath) + n := matchLen(p, path) if n > bestMatch || (bestMatch == 0 && !seenAnyThirdParty && isThirdPartyPath) { bestMatch = n impDecl = gen @@ -197,11 +197,13 @@ func isThirdParty(importPath string) bool { } // DeleteImport deletes the import path from the file f, if present. +// If there are duplicate import declarations, all matching ones are deleted. func DeleteImport(fset *token.FileSet, f *ast.File, path string) (deleted bool) { return DeleteNamedImport(fset, f, "", path) } // DeleteNamedImport deletes the import with the given name and path from the file f, if present. +// If there are duplicate import declarations, all matching ones are deleted. func DeleteNamedImport(fset *token.FileSet, f *ast.File, name, path string) (deleted bool) { var delspecs []*ast.ImportSpec var delcomments []*ast.CommentGroup @@ -216,13 +218,7 @@ func DeleteNamedImport(fset *token.FileSet, f *ast.File, name, path string) (del for j := 0; j < len(gen.Specs); j++ { spec := gen.Specs[j] impspec := spec.(*ast.ImportSpec) - if impspec.Name == nil && name != "" { - continue - } - if impspec.Name != nil && impspec.Name.Name != name { - continue - } - if importPath(impspec) != path { + if importName(impspec) != name || importPath(impspec) != path { continue } @@ -383,9 +379,14 @@ func (fn visitFn) Visit(node ast.Node) ast.Visitor { return fn } -// imports returns true if f imports path. -func imports(f *ast.File, path string) bool { - return importSpec(f, path) != nil +// imports reports whether f has an import with the specified name and path. +func imports(f *ast.File, name, path string) bool { + for _, s := range f.Imports { + if importName(s) == name && importPath(s) == path { + return true + } + } + return false } // importSpec returns the import spec if f imports path, @@ -399,14 +400,23 @@ func importSpec(f *ast.File, path string) *ast.ImportSpec { return nil } +// importName returns the name of s, +// or "" if the import is not named. +func importName(s *ast.ImportSpec) string { + if s.Name == nil { + return "" + } + return s.Name.Name +} + // importPath returns the unquoted import path of s, // or "" if the path is not properly quoted. func importPath(s *ast.ImportSpec) string { t, err := strconv.Unquote(s.Path.Value) - if err == nil { - return t + if err != nil { + return "" } - return "" + return t } // declImports reports whether gen contains an import of path. diff --git a/go/ast/astutil/imports_test.go b/go/ast/astutil/imports_test.go index da775ef4..1d86e477 100644 --- a/go/ast/astutil/imports_test.go +++ b/go/ast/astutil/imports_test.go @@ -30,7 +30,7 @@ func print(t *testing.T, name string, f *ast.File) string { if err := format.Node(&buf, fset, f); err != nil { t.Fatalf("%s gofmt: %v", name, err) } - return string(buf.Bytes()) + return buf.String() } type test struct { @@ -39,7 +39,7 @@ type test struct { pkg string in string out string - broken bool // known broken + unchanged bool // Expect added/deleted return value to be false. } var addTests = []test{ @@ -58,6 +58,7 @@ import ( "os" ) `, + unchanged: true, }, { name: "import.1", @@ -657,6 +658,117 @@ import ( ) `, }, + + // Issue 28605: Add specified import, even if that import path is imported under another name + { + name: "issue 28605 add unnamed path", + renamedPkg: "", + pkg: "path", + in: `package main + +import ( + . "path" + _ "path" + pathpkg "path" +) +`, + out: `package main + +import ( + "path" + . "path" + _ "path" + pathpkg "path" +) +`, + }, + { + name: "issue 28605 add pathpkg-renamed path", + renamedPkg: "pathpkg", + pkg: "path", + in: `package main + +import ( + "path" + . "path" + _ "path" +) +`, + out: `package main + +import ( + "path" + . "path" + _ "path" + pathpkg "path" +) +`, + }, + { + name: "issue 28605 add blank identifier path", + renamedPkg: "_", + pkg: "path", + in: `package main + +import ( + "path" + . "path" + pathpkg "path" +) +`, + out: `package main + +import ( + "path" + . "path" + _ "path" + pathpkg "path" +) +`, + }, + { + name: "issue 28605 add dot import path", + renamedPkg: ".", + pkg: "path", + in: `package main + +import ( + "path" + _ "path" + pathpkg "path" +) +`, + out: `package main + +import ( + "path" + . "path" + _ "path" + pathpkg "path" +) +`, + }, + + { + name: "duplicate import declarations, add existing one", + renamedPkg: "f", + pkg: "fmt", + in: `package main + +import "fmt" +import "fmt" +import f "fmt" +import f "fmt" +`, + out: `package main + +import "fmt" +import "fmt" +import f "fmt" +import f "fmt" +`, + unchanged: true, + }, } func TestAddImport(t *testing.T) { @@ -664,18 +776,26 @@ func TestAddImport(t *testing.T) { file := parse(t, test.name, test.in) var before bytes.Buffer ast.Fprint(&before, fset, file, nil) - AddNamedImport(fset, file, test.renamedPkg, test.pkg) + added := AddNamedImport(fset, file, test.renamedPkg, test.pkg) if got := print(t, test.name, file); got != test.out { - if test.broken { - t.Logf("%s is known broken:\ngot: %s\nwant: %s", test.name, got, test.out) - } else { - t.Errorf("%s:\ngot: %s\nwant: %s", test.name, got, test.out) - } + t.Errorf("first run: %s:\ngot: %s\nwant: %s", test.name, got, test.out) var after bytes.Buffer ast.Fprint(&after, fset, file, nil) - t.Logf("AST before:\n%s\nAST after:\n%s\n", before.String(), after.String()) } + if got, want := added, !test.unchanged; got != want { + t.Errorf("first run: %s: added = %v, want %v", test.name, got, want) + } + + // AddNamedImport should be idempotent. Verify that by calling it again, + // expecting no change to the AST, and the returned added value to always be false. + added = AddNamedImport(fset, file, test.renamedPkg, test.pkg) + if got := print(t, test.name, file); got != test.out { + t.Errorf("second run: %s:\ngot: %s\nwant: %s", test.name, got, test.out) + } + if got, want := added, false; got != want { + t.Errorf("second run: %s: added = %v, want %v", test.name, got, want) + } } } @@ -1405,14 +1525,161 @@ import ( ) `, }, + + // Issue 28605: Delete specified import, even if that import path is imported under another name + { + name: "import.38", + renamedPkg: "", + pkg: "path", + in: `package main + +import ( + "path" + . "path" + _ "path" + pathpkg "path" +) +`, + out: `package main + +import ( + . "path" + _ "path" + pathpkg "path" +) +`, + }, + { + name: "import.39", + renamedPkg: "pathpkg", + pkg: "path", + in: `package main + +import ( + "path" + . "path" + _ "path" + pathpkg "path" +) +`, + out: `package main + +import ( + "path" + . "path" + _ "path" +) +`, + }, + { + name: "import.40", + renamedPkg: "_", + pkg: "path", + in: `package main + +import ( + "path" + . "path" + _ "path" + pathpkg "path" +) +`, + out: `package main + +import ( + "path" + . "path" + pathpkg "path" +) +`, + }, + { + name: "import.41", + renamedPkg: ".", + pkg: "path", + in: `package main + +import ( + "path" + . "path" + _ "path" + pathpkg "path" +) +`, + out: `package main + +import ( + "path" + _ "path" + pathpkg "path" +) +`, + }, + + // Duplicate import declarations, all matching ones are deleted. + { + name: "import.42", + renamedPkg: "f", + pkg: "fmt", + in: `package main + +import "fmt" +import "fmt" +import f "fmt" +import f "fmt" +`, + out: `package main + +import "fmt" +import "fmt" +`, + }, + { + name: "import.43", + renamedPkg: "x", + pkg: "fmt", + in: `package main + +import "fmt" +import "fmt" +import f "fmt" +import f "fmt" +`, + out: `package main + +import "fmt" +import "fmt" +import f "fmt" +import f "fmt" +`, + unchanged: true, + }, } func TestDeleteImport(t *testing.T) { for _, test := range deleteTests { file := parse(t, test.name, test.in) - DeleteNamedImport(fset, file, test.renamedPkg, test.pkg) + var before bytes.Buffer + ast.Fprint(&before, fset, file, nil) + deleted := DeleteNamedImport(fset, file, test.renamedPkg, test.pkg) if got := print(t, test.name, file); got != test.out { - t.Errorf("%s:\ngot: %s\nwant: %s", test.name, got, test.out) + t.Errorf("first run: %s:\ngot: %s\nwant: %s", test.name, got, test.out) + var after bytes.Buffer + ast.Fprint(&after, fset, file, nil) + t.Logf("AST before:\n%s\nAST after:\n%s\n", before.String(), after.String()) + } + if got, want := deleted, !test.unchanged; got != want { + t.Errorf("first run: %s: deleted = %v, want %v", test.name, got, want) + } + + // DeleteNamedImport should be idempotent. Verify that by calling it again, + // expecting no change to the AST, and the returned deleted value to always be false. + deleted = DeleteNamedImport(fset, file, test.renamedPkg, test.pkg) + if got := print(t, test.name, file); got != test.out { + t.Errorf("second run: %s:\ngot: %s\nwant: %s", test.name, got, test.out) + } + if got, want := deleted, false; got != want { + t.Errorf("second run: %s: deleted = %v, want %v", test.name, got, want) } } }