diff --git a/cmd/goimports/goimports.go b/cmd/goimports/goimports.go index c48cadbb..a3448547 100644 --- a/cmd/goimports/goimports.go +++ b/cmd/goimports/goimports.go @@ -30,6 +30,7 @@ var ( TabWidth: 8, TabIndent: true, Comments: true, + Fragment: true, } exitCode = 0 ) diff --git a/imports/fix_test.go b/imports/fix_test.go index f96f8535..d4646384 100644 --- a/imports/fix_test.go +++ b/imports/fix_test.go @@ -521,6 +521,25 @@ import str "strings" var _ = str.HasPrefix `, }, + + { + name: "fragment with main", + in: `func main(){fmt.Println("Hello, world")}`, + out: `package main + +import "fmt" + +func main() { fmt.Println("Hello, world") } +`, + }, + + { + name: "fragment without main", + in: `func notmain(){fmt.Println("Hello, world")}`, + out: `import "fmt" + +func notmain() { fmt.Println("Hello, world") }`, + }, } func TestFixImports(t *testing.T) { @@ -539,11 +558,18 @@ func TestFixImports(t *testing.T) { return simplePkgs[pkgName], pkgName == "str", nil } + options := &Options{ + TabWidth: 8, + TabIndent: true, + Comments: true, + Fragment: true, + } + for _, tt := range tests { if *only != "" && tt.name != *only { continue } - buf, err := Process(tt.name+".go", []byte(tt.in), nil) + buf, err := Process(tt.name+".go", []byte(tt.in), options) if err != nil { t.Errorf("error on %q: %v", tt.name, err) continue diff --git a/imports/imports.go b/imports/imports.go index 67869eff..dd1bc4d6 100644 --- a/imports/imports.go +++ b/imports/imports.go @@ -119,13 +119,19 @@ func parse(fset *token.FileSet, filename string, src []byte, opt *Options) (*ast // by inserting a package clause. // Insert using a ;, not a newline, so that the line numbers // in psrc match the ones in src. - psrc := append([]byte("package p;"), src...) + psrc := append([]byte("package main;"), src...) file, err = parser.ParseFile(fset, filename, psrc, parserMode) if err == nil { + // If a main function exists, we will assume this is a main + // package and leave the file. + if containsMainFunc(file) { + return file, nil, nil + } + adjust := func(orig, src []byte) []byte { // Remove the package clause. // Gofmt has turned the ; into a \n. - src = src[len("package p\n"):] + src = src[len("package main\n"):] return matchSpace(orig, src) } return file, adjust, nil @@ -162,6 +168,30 @@ func parse(fset *token.FileSet, filename string, src []byte, opt *Options) (*ast return nil, nil, err } +// containsMainFunc checks if a file contains a function declaration with the +// function signature 'func main()' +func containsMainFunc(file *ast.File) bool { + for _, decl := range file.Decls { + if f, ok := decl.(*ast.FuncDecl); ok { + if f.Name.Name != "main" { + continue + } + + if len(f.Type.Params.List) != 0 { + continue + } + + if f.Type.Results != nil && len(f.Type.Results.List) != 0 { + continue + } + + return true + } + } + + return false +} + func cutSpace(b []byte) (before, middle, after []byte) { i := 0 for i < len(b) && (b[i] == ' ' || b[i] == '\t' || b[i] == '\n') {