diff --git a/imports/fix_test.go b/imports/fix_test.go index 10e397d4..539f2cbb 100644 --- a/imports/fix_test.go +++ b/imports/fix_test.go @@ -936,6 +936,72 @@ func TestFixImports(t *testing.T) { } } +func TestProcess_nil_src(t *testing.T) { + dir, err := ioutil.TempDir("", "goimports-") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(dir) + tests := []struct { + name string + in, out string + }{ + { + name: "nil-src", + in: `package foo +func bar() { +fmt.Println("hi") +} +`, + out: `package foo + +import "fmt" + +func bar() { + fmt.Println("hi") +} +`, + }, + { + name: "missing package", + in: ` +func bar() { +fmt.Println("hi") +} +`, + out: ` +import "fmt" + +func bar() { + fmt.Println("hi") +} +`, + }, + } + + options := &Options{ + TabWidth: 8, + TabIndent: true, + Comments: true, + Fragment: true, + } + + for _, tt := range tests { + filename := filepath.Join(dir, tt.name+".go") + if err := ioutil.WriteFile(filename, []byte(tt.in), 0666); err != nil { + t.Fatal(err) + } + buf, err := Process(filename, nil, options) + if err != nil { + t.Errorf("error on %q: %v", tt.name, err) + continue + } + if got := string(buf); got != tt.out { + t.Errorf("results diff on %q\nGOT:\n%s\nWANT:\n%s\n", tt.name, got, tt.out) + } + } +} + // Test support for packages in GOPATH that are actually symlinks. // Also test that a symlink loop does not block the process. func TestImportSymlinks(t *testing.T) { diff --git a/imports/imports.go b/imports/imports.go index 06232487..a4cbf5c7 100644 --- a/imports/imports.go +++ b/imports/imports.go @@ -18,6 +18,7 @@ import ( "go/printer" "go/token" "io" + "io/ioutil" "regexp" "strconv" "strings" @@ -47,6 +48,13 @@ func Process(filename string, src []byte, opt *Options) ([]byte, error) { if opt == nil { opt = &Options{Comments: true, TabIndent: true, TabWidth: 8} } + if src == nil { + b, err := ioutil.ReadFile(filename) + if err != nil { + return nil, err + } + src = b + } fileSet := token.NewFileSet() file, adjust, err := parse(fileSet, filename, src, opt)