From 0d4ee40f21419f2ed88980d978011b133d8bea8e Mon Sep 17 00:00:00 2001 From: David Crawshaw Date: Fri, 8 Nov 2013 14:51:15 -0500 Subject: [PATCH] go.tools/astutil: add Imports, which returns imports grouped by spacing. R=bradfitz CC=golang-dev https://golang.org/cl/23660045 --- astutil/imports.go | 30 ++++++++++++ astutil/imports_test.go | 102 +++++++++++++++++++++++++++++++++++++++- 2 files changed, 131 insertions(+), 1 deletion(-) diff --git a/astutil/imports.go b/astutil/imports.go index c07676cf..9ce61248 100644 --- a/astutil/imports.go +++ b/astutil/imports.go @@ -323,3 +323,33 @@ func isTopName(n ast.Expr, name string) bool { id, ok := n.(*ast.Ident) return ok && id.Name == name && id.Obj == nil } + +// Imports returns the file imports grouped by paragraph. +func Imports(fset *token.FileSet, f *ast.File) [][]*ast.ImportSpec { + var groups [][]*ast.ImportSpec + + for _, decl := range f.Decls { + genDecl, ok := decl.(*ast.GenDecl) + if !ok || genDecl.Tok != token.IMPORT { + break + } + + group := []*ast.ImportSpec{} + + var lastLine int + for _, spec := range genDecl.Specs { + importSpec := spec.(*ast.ImportSpec) + pos := importSpec.Path.ValuePos + line := fset.Position(pos).Line + if lastLine > 0 && pos > 0 && line-lastLine > 1 { + groups = append(groups, group) + group = []*ast.ImportSpec{} + } + group = append(group, importSpec) + lastLine = line + } + groups = append(groups, group) + } + + return groups +} diff --git a/astutil/imports_test.go b/astutil/imports_test.go index e47889f7..3ecfe518 100644 --- a/astutil/imports_test.go +++ b/astutil/imports_test.go @@ -6,7 +6,8 @@ import ( "go/format" "go/parser" "go/token" - + "reflect" + "strconv" "testing" ) @@ -528,3 +529,102 @@ func TestRenameTop(t *testing.T) { } } } + +var importsTests = []struct { + name string + in string + want [][]string +}{ + { + name: "no packages", + in: `package foo +`, + want: nil, + }, + { + name: "one group", + in: `package foo + +import ( + "fmt" + "testing" +) +`, + want: [][]string{{"fmt", "testing"}}, + }, + { + name: "four groups", + in: `package foo + +import "C" +import ( + "fmt" + "testing" + + "appengine" + + "myproject/mylib1" + "myproject/mylib2" +) +`, + want: [][]string{ + {"C"}, + {"fmt", "testing"}, + {"appengine"}, + {"myproject/mylib1", "myproject/mylib2"}, + }, + }, + { + name: "multiple factored groups", + in: `package foo + +import ( + "fmt" + "testing" + + "appengine" +) +import ( + "reflect" + + "bytes" +) +`, + want: [][]string{ + {"fmt", "testing"}, + {"appengine"}, + {"reflect"}, + {"bytes"}, + }, + }, +} + +func unquote(s string) string { + res, err := strconv.Unquote(s) + if err != nil { + return "could_not_unquote" + } + return res +} + +func TestImports(t *testing.T) { + fset := token.NewFileSet() + for _, test := range importsTests { + f, err := parser.ParseFile(fset, "test.go", test.in, 0) + if err != nil { + t.Errorf("%s: %v", test.name, err) + continue + } + var got [][]string + for _, block := range Imports(fset, f) { + var b []string + for _, spec := range block { + b = append(b, unquote(spec.Path.Value)) + } + got = append(got, b) + } + if !reflect.DeepEqual(got, test.want) { + t.Errorf("Imports(%s)=%v, want %v", test.name, got, test.want) + } + } +}