diff --git a/refactor/rename/rename.go b/refactor/rename/rename.go index 53d88801..12fc149f 100644 --- a/refactor/rename/rename.go +++ b/refactor/rename/rename.go @@ -304,8 +304,21 @@ func plural(n int) string { return "" } +func writeFile(name string, fset *token.FileSet, f *ast.File) error { + out, err := os.Create(name) + if err != nil { + // assume error includes the filename + return fmt.Errorf("failed to open file: %s", err) + } + if err := format.Node(out, fset, f); err != nil { + out.Close() // ignore error + return fmt.Errorf("failed to write file: %s", err) + } + return out.Close() +} + var rewriteFile = func(fset *token.FileSet, f *ast.File, orig string) (err error) { - backup := orig + ".prename" + backup := orig + ".gorename.backup" // TODO(adonovan): print packages and filenames in a form useful // to editors (so they can reload files). if Verbose { @@ -315,18 +328,12 @@ var rewriteFile = func(fset *token.FileSet, f *ast.File, orig string) (err error return fmt.Errorf("failed to make backup %s -> %s: %s", orig, filepath.Base(backup), err) } - out, err := os.Create(orig) - if err != nil { - // assume error includes the filename - return fmt.Errorf("failed to open file: %s", err) - } - defer func() { - if closeErr := out.Close(); err == nil { - err = closeErr // don't clobber existing error - } - }() - if err := format.Node(out, fset, f); err != nil { - return fmt.Errorf("failed to write file: %s", err) + if err := writeFile(orig, fset, f); err != nil { + // Restore the file from the backup. + os.Remove(orig) // ignore error + os.Rename(backup, orig) // ignore error + return err } + os.Remove(backup) // ignore error return nil }