diff --git a/cmd/eg/eg.go b/cmd/eg/eg.go new file mode 100644 index 00000000..3a12d33d --- /dev/null +++ b/cmd/eg/eg.go @@ -0,0 +1,121 @@ +// The eg command performs example-based refactoring. +package main + +import ( + "flag" + "fmt" + "go/parser" + "go/printer" + "go/token" + "os" + "path/filepath" + + "code.google.com/p/go.tools/go/loader" + "code.google.com/p/go.tools/refactor/eg" +) + +var ( + helpFlag = flag.Bool("help", false, "show detailed help message") + templateFlag = flag.String("t", "", "template.go file specifying the refactoring") + transitiveFlag = flag.Bool("transitive", false, "apply refactoring to all dependencies too") + writeFlag = flag.Bool("w", false, "rewrite input files in place (by default, the results are printed to standard output)") + verboseFlag = flag.Bool("v", false, "show verbose matcher diagnostics") +) + +const usage = `eg: an example-based refactoring tool. + +Usage: eg -t template.go [-w] [-transitive] ... +-t template.go specifies the template file (use -help to see explanation) +-w causes files to be re-written in place. +-transitive causes all dependencies to be refactored too. +` + loader.FromArgsUsage + +func main() { + if err := doMain(); err != nil { + fmt.Fprintf(os.Stderr, "%s: %s.\n", filepath.Base(os.Args[0]), err) + os.Exit(1) + } +} + +func doMain() error { + flag.Parse() + args := flag.Args() + + if *helpFlag { + fmt.Fprintf(os.Stderr, eg.Help) + os.Exit(2) + } + + if *templateFlag == "" { + return fmt.Errorf("no -t template.go file specified") + } + + conf := loader.Config{ + Fset: token.NewFileSet(), + ParserMode: parser.ParseComments, + SourceImports: true, + } + + // The first Created package is the template. + if err := conf.CreateFromFilenames("template", *templateFlag); err != nil { + return err // e.g. "foo.go:1: syntax error" + } + + if len(args) == 0 { + fmt.Fprint(os.Stderr, usage) + os.Exit(1) + } + + if _, err := conf.FromArgs(args, true); err != nil { + return err + } + + // Load, parse and type-check the whole program. + iprog, err := conf.Load() + if err != nil { + return err + } + + // Analyze the template. + template := iprog.Created[0] + xform, err := eg.NewTransformer(iprog.Fset, template, *verboseFlag) + if err != nil { + return err + } + + // Apply it to the input packages. + var pkgs []*loader.PackageInfo + if *transitiveFlag { + for _, info := range iprog.AllPackages { + pkgs = append(pkgs, info) + } + } else { + pkgs = iprog.InitialPackages() + } + var hadErrors bool + for _, pkg := range pkgs { + if pkg == template { + continue + } + for _, file := range pkg.Files { + n := xform.Transform(&pkg.Info, pkg.Pkg, file) + if n == 0 { + continue + } + filename := iprog.Fset.File(file.Pos()).Name() + fmt.Fprintf(os.Stderr, "=== %s (%d matches):\n", filename, n) + if *writeFlag { + if err := eg.WriteAST(iprog.Fset, filename, file); err != nil { + fmt.Fprintf(os.Stderr, "Error: %s\n", err) + hadErrors = true + } + } else { + printer.Fprint(os.Stdout, iprog.Fset, file) + } + } + } + if hadErrors { + os.Exit(1) + } + return nil +} diff --git a/refactor/README b/refactor/README new file mode 100644 index 00000000..a6edb135 --- /dev/null +++ b/refactor/README @@ -0,0 +1 @@ +code.google.com/p/go.tools/refactor: libraries for refactoring tools. diff --git a/refactor/eg/eg.go b/refactor/eg/eg.go new file mode 100644 index 00000000..5afe635e --- /dev/null +++ b/refactor/eg/eg.go @@ -0,0 +1,326 @@ +// Package eg implements the example-based refactoring tool whose +// command-line is defined in code.google.com/p/go.tools/cmd/eg. +package eg + +import ( + "bytes" + "fmt" + "go/ast" + "go/printer" + "go/token" + "os" + + "code.google.com/p/go.tools/go/loader" + "code.google.com/p/go.tools/go/types" +) + +const Help = ` +This tool implements example-based refactoring of expressions. + +The transformation is specified as a Go file defining two functions, +'before' and 'after', of identical types. Each function body consists +of a single statement: either a return statement with a single +(possibly multi-valued) expression, or an expression statement. The +'before' expression specifies a pattern and the 'after' expression its +replacement. + + package P + import ( "errors"; "fmt" ) + func before(s string) error { return fmt.Errorf("%s", s) } + func after(s string) error { return errors.New(s) } + +The expression statement form is useful when the the expression has no +result, for example: + + func before(msg string) { log.Fatalf("%s", msg) } + func after(msg string) { log.Fatal(msg) } + +The parameters of both functions are wildcards that may match any +expression assignable to that type. If the pattern contains multiple +occurrences of the same parameter, each must match the same expression +in the input for the pattern to match. If the replacement contains +multiple occurrences of the same parameter, the expression will be +duplicated, possibly changing the side-effects. + +The tool analyses all Go code in the packages specified by the +arguments, replacing all occurrences of the pattern with the +substitution. + +So, the transform above would change this input: + err := fmt.Errorf("%s", "error: " + msg) +to this output: + err := errors.New("error: " + msg) + +Identifiers, including qualified identifiers (p.X) are considered to +match only if they denote the same object. This allows correct +matching even in the presence of dot imports, named imports and +locally shadowed package names in the input program. + +Matching of type syntax is semantic, not syntactic: type syntax in the +pattern matches type syntax in the input if the types are identical. +Thus, func(x int) matches func(y int). + +This tool was inspired by other example-based refactoring tools, +'gofmt -r' for Go and Refaster for Java. + + +LIMITATIONS +=========== + +EXPRESSIVENESS + +Only refactorings that replace one expression with another, regardless +of the expression's context, may be expressed. Refactoring arbitrary +statements (or sequences of statements) is a less well-defined problem +and is less amenable to this approach. + +A pattern that contains a function literal (and hence statements) +never matches. + +There is no way to generalize over related types, e.g. to express that +a wildcard may have any integer type, for example. + + +SAFETY + +Verifying that a transformation does not introduce type errors is very +complex in the general case. An innocuous-looking replacement of one +constant by another (e.g. 1 to 2) may cause type errors relating to +array types and indices, for example. The tool performs only very +superficial checks of type preservation. + +It is not possible to replace an expression by one of a different +type, even in contexts where this is legal, such as x in fmt.Print(x). + + +IMPORTS + +Although the matching algorithm is fully aware of scoping rules, the +replacement algorithm is not, so the replacement code may contain +incorrect identifier syntax for imported objects if there are dot +imports, named imports or locally shadowed package names in the input +program. + +Imports are added as needed, but they are not removed as needed. +Run 'goimports' on the modified file for now. + +Dot imports are forbidden in the template. +` + +// TODO(adonovan): allow the tool to be invoked using relative package +// directory names (./foo). Requires changes to go/loader. + +// TODO(adonovan): expand upon the above documentation as an HTML page. + +// TODO(adonovan): eliminate dependency on loader.PackageInfo. +// Move its ObjectOf/IsType/TypeOf methods into go/types. + +// A Transformer represents a single example-based transformation. +type Transformer struct { + fset *token.FileSet + verbose bool + info loader.PackageInfo // combined type info for template/input/output ASTs + seenInfos map[*types.Info]bool + wildcards map[*types.Var]bool // set of parameters in func before() + env map[string]ast.Expr // maps parameter name to wildcard binding + importedObjs map[types.Object]*ast.SelectorExpr // objects imported by after(). + before, after ast.Expr + allowWildcards bool + + // Working state of Transform(): + nsubsts int // number of substitutions made + currentPkg *types.Package // package of current call +} + +// NewTransformer returns a transformer based on the specified template, +// a package containing "before" and "after" functions as described +// in the package documentation. +// +func NewTransformer(fset *token.FileSet, template *loader.PackageInfo, verbose bool) (*Transformer, error) { + // Check the template. + beforeSig := funcSig(template.Pkg, "before") + if beforeSig == nil { + return nil, fmt.Errorf("no 'before' func found in template") + } + afterSig := funcSig(template.Pkg, "after") + if afterSig == nil { + return nil, fmt.Errorf("no 'after' func found in template") + } + + // TODO(adonovan): should we also check the names of the params match? + if !types.Identical(afterSig, beforeSig) { + return nil, fmt.Errorf("before %s and after %s functions have different signatures", + beforeSig, afterSig) + } + + templateFile := template.Files[0] + for _, imp := range templateFile.Imports { + if imp.Name != nil && imp.Name.Name == "." { + // Dot imports are currently forbidden. We + // make the simplifying assumption that all + // imports are regular, without local renames. + //TODO document + return nil, fmt.Errorf("dot-import (of %s) in template", imp.Path.Value) + } + } + var beforeDecl, afterDecl *ast.FuncDecl + for _, decl := range templateFile.Decls { + if decl, ok := decl.(*ast.FuncDecl); ok { + switch decl.Name.Name { + case "before": + beforeDecl = decl + case "after": + afterDecl = decl + } + } + } + + before, err := soleExpr(beforeDecl) + if err != nil { + return nil, fmt.Errorf("before: %s", err) + } + after, err := soleExpr(afterDecl) + if err != nil { + return nil, fmt.Errorf("after: %s", err) + } + + wildcards := make(map[*types.Var]bool) + for i := 0; i < beforeSig.Params().Len(); i++ { + wildcards[beforeSig.Params().At(i)] = true + } + + // checkExprTypes returns an error if Tb (type of before()) is not + // safe to replace with Ta (type of after()). + // + // Only superficial checks are performed, and they may result in both + // false positives and negatives. + // + // Ideally, we would only require that the replacement be assignable + // to the context of a specific pattern occurrence, but the type + // checker doesn't record that information and it's complex to deduce. + // A Go type cannot capture all the constraints of a given expression + // context, which may include the size, constness, signedness, + // namedness or constructor of its type, and even the specific value + // of the replacement. (Consider the rule that array literal keys + // must be unique.) So we cannot hope to prove the safety of a + // transformation in general. + Tb := template.TypeOf(before) + Ta := template.TypeOf(after) + if types.AssignableTo(Tb, Ta) { + // safe: replacement is assignable to pattern. + } else if tuple, ok := Tb.(*types.Tuple); ok && tuple.Len() == 0 { + // safe: pattern has void type (must appear in an ExprStmt). + } else { + return nil, fmt.Errorf("%s is not a safe replacement for %s", Ta, Tb) + } + + tr := &Transformer{ + fset: fset, + verbose: verbose, + wildcards: wildcards, + allowWildcards: true, + seenInfos: make(map[*types.Info]bool), + importedObjs: make(map[types.Object]*ast.SelectorExpr), + before: before, + after: after, + } + + // Combine type info from the template and input packages, and + // type info for the synthesized ASTs too. This saves us + // having to book-keep where each ast.Node originated as we + // construct the resulting hybrid AST. + // + // TODO(adonovan): move type utility methods of PackageInfo to + // types.Info, or at least into go/types.typeutil. + tr.info.Info = types.Info{ + Types: make(map[ast.Expr]types.TypeAndValue), + Defs: make(map[*ast.Ident]types.Object), + Uses: make(map[*ast.Ident]types.Object), + Selections: make(map[*ast.SelectorExpr]*types.Selection), + } + mergeTypeInfo(&tr.info.Info, &template.Info) + + // Compute set of imported objects required by after(). + // TODO reject dot-imports in pattern + ast.Inspect(after, func(n ast.Node) bool { + if n, ok := n.(*ast.SelectorExpr); ok { + sel := tr.info.Selections[n] + if sel.Kind() == types.PackageObj { + tr.importedObjs[sel.Obj()] = n + return false // prune + } + } + return true // recur + }) + + return tr, nil +} + +// WriteAST is a convenience function that writes AST f to the specified file. +func WriteAST(fset *token.FileSet, filename string, f *ast.File) (err error) { + fh, err := os.Create(filename) + if err != nil { + return err + } + defer func() { + if err2 := fh.Close(); err != nil { + err = err2 // prefer earlier error + } + }() + return printer.Fprint(fh, fset, f) +} + +// -- utilities -------------------------------------------------------- + +// funcSig returns the signature of the specified package-level function. +func funcSig(pkg *types.Package, name string) *types.Signature { + if f, ok := pkg.Scope().Lookup(name).(*types.Func); ok { + return f.Type().(*types.Signature) + } + return nil +} + +// soleExpr returns the sole expression in the before/after template function. +func soleExpr(fn *ast.FuncDecl) (ast.Expr, error) { + if fn.Body == nil { + return nil, fmt.Errorf("no body") + } + if len(fn.Body.List) != 1 { + return nil, fmt.Errorf("must contain a single statement") + } + switch stmt := fn.Body.List[0].(type) { + case *ast.ReturnStmt: + if len(stmt.Results) != 1 { + return nil, fmt.Errorf("return statement must have a single operand") + } + return stmt.Results[0], nil + + case *ast.ExprStmt: + return stmt.X, nil + } + + return nil, fmt.Errorf("must contain a single return or expression statement") +} + +// mergeTypeInfo adds type info from src to dst. +func mergeTypeInfo(dst, src *types.Info) { + for k, v := range src.Types { + dst.Types[k] = v + } + for k, v := range src.Defs { + dst.Defs[k] = v + } + for k, v := range src.Uses { + dst.Uses[k] = v + } + for k, v := range src.Selections { + dst.Selections[k] = v + } +} + +// (debugging only) +func astString(fset *token.FileSet, n ast.Node) string { + var buf bytes.Buffer + printer.Fprint(&buf, fset, n) + return buf.String() +} diff --git a/refactor/eg/eg_test.go b/refactor/eg/eg_test.go new file mode 100644 index 00000000..38174029 --- /dev/null +++ b/refactor/eg/eg_test.go @@ -0,0 +1,145 @@ +package eg_test + +import ( + "bytes" + "flag" + "go/parser" + "go/token" + "os" + "os/exec" + "path/filepath" + "runtime" + "strings" + "testing" + + "code.google.com/p/go.tools/go/exact" + "code.google.com/p/go.tools/go/loader" + "code.google.com/p/go.tools/go/types" + "code.google.com/p/go.tools/refactor/eg" +) + +// TODO(adonovan): more tests: +// - of command-line tool +// - of all parts of syntax +// - of applying a template to a package it imports: +// the replacement syntax should use unqualified names for its objects. + +var ( + updateFlag = flag.Bool("update", false, "update the golden files") + verboseFlag = flag.Bool("verbose", false, "show matcher information") +) + +func Test(t *testing.T) { + switch runtime.GOOS { + case "windows": + t.Skipf("skipping test on %q (no /usr/bin/diff)", runtime.GOOS) + } + + conf := loader.Config{ + Fset: token.NewFileSet(), + ParserMode: parser.ParseComments, + SourceImports: true, + } + + // Each entry is a single-file package. + // (Multi-file packages aren't interesting for this test.) + // Order matters: each non-template package is processed using + // the preceding template package. + for _, filename := range []string{ + "testdata/A.template", + "testdata/A1.go", + "testdata/A2.go", + + "testdata/B.template", + "testdata/B1.go", + + "testdata/C.template", + "testdata/C1.go", + + "testdata/D.template", + "testdata/D1.go", + + "testdata/E.template", + "testdata/E1.go", + + "testdata/bad_type.template", + "testdata/no_before.template", + "testdata/no_after_return.template", + "testdata/type_mismatch.template", + "testdata/expr_type_mismatch.template", + } { + pkgname := strings.TrimSuffix(filepath.Base(filename), ".go") + if err := conf.CreateFromFilenames(pkgname, filename); err != nil { + t.Fatal(err) + } + } + iprog, err := conf.Load() + if err != nil { + t.Fatal(err) + } + + var xform *eg.Transformer + for _, info := range iprog.Created { + file := info.Files[0] + filename := iprog.Fset.File(file.Pos()).Name() // foo.go + + if strings.HasSuffix(filename, "template") { + // a new template + shouldFail, _ := info.Pkg.Scope().Lookup("shouldFail").(*types.Const) + xform, err = eg.NewTransformer(iprog.Fset, info, *verboseFlag) + if err != nil { + if shouldFail == nil { + t.Errorf("NewTransformer(%s): %s", filename, err) + } else if want := exact.StringVal(shouldFail.Val()); !strings.Contains(err.Error(), want) { + t.Errorf("NewTransformer(%s): got error %q, want error %q", filename, err, want) + } + } else if shouldFail != nil { + t.Errorf("NewTransformer(%s) succeeded unexpectedly; want error %q", + filename, shouldFail.Val()) + } + continue + } + + if xform == nil { + t.Errorf("%s: no previous template", filename) + continue + } + + // apply previous template to this package + n := xform.Transform(&info.Info, info.Pkg, file) + if n == 0 { + t.Errorf("%s: no matches", filename) + continue + } + + got := filename + "t" // foo.got + golden := filename + "lden" // foo.golden + + // Write actual output to foo.got. + if err := eg.WriteAST(iprog.Fset, got, file); err != nil { + t.Error(err) + } + + // Compare foo.got with foo.golden. + var cmd *exec.Cmd + switch runtime.GOOS { + case "plan9": + cmd = exec.Command("/bin/diff", "-c", golden, got) + default: + cmd = exec.Command("/usr/bin/diff", "-u", "-N", golden, got) + } + buf := new(bytes.Buffer) + cmd.Stdout = buf + cmd.Stderr = os.Stderr + if err := cmd.Run(); err != nil { + t.Errorf("eg tests for %s failed: %s.\n%s\n", filename, err, buf) + + if *updateFlag { + t.Logf("Updating %s...", golden) + if err := exec.Command("/bin/cp", got, golden).Run(); err != nil { + t.Errorf("Update failed: %s", err) + } + } + } + } +} diff --git a/refactor/eg/match.go b/refactor/eg/match.go new file mode 100644 index 00000000..7476f9ac --- /dev/null +++ b/refactor/eg/match.go @@ -0,0 +1,226 @@ +package eg + +import ( + "fmt" + "go/ast" + "go/token" + "log" + "os" + "reflect" + + "code.google.com/p/go.tools/go/exact" + "code.google.com/p/go.tools/go/loader" + "code.google.com/p/go.tools/go/types" +) + +// matchExpr reports whether pattern x matches y. +// +// If tr.allowWildcards, Idents in x that refer to parameters are +// treated as wildcards, and match any y that is assignable to the +// parameter type; matchExpr records this correspondence in tr.env. +// Otherwise, matchExpr simply reports whether the two trees are +// equivalent. +// +// A wildcard appearing more than once in the pattern must +// consistently match the same tree. +// +func (tr *Transformer) matchExpr(x, y ast.Expr) bool { + if x == nil && y == nil { + return true + } + if x == nil || y == nil { + return false + } + x = unparen(x) + y = unparen(y) + + // Is x a wildcard? (a reference to a 'before' parameter) + if x, ok := x.(*ast.Ident); ok && x != nil && tr.allowWildcards { + if xobj, ok := tr.info.Uses[x].(*types.Var); ok && tr.wildcards[xobj] { + return tr.matchWildcard(xobj, y) + } + } + + // Object identifiers (including pkg-qualified ones) + // are handled semantically, not syntactically. + xobj := isRef(x, &tr.info) + yobj := isRef(y, &tr.info) + if xobj != nil { + return xobj == yobj + } + if yobj != nil { + return false + } + + // TODO(adonovan): audit: we cannot assume these ast.Exprs + // contain non-nil pointers. e.g. ImportSpec.Name may be a + // nil *ast.Ident. + + if reflect.TypeOf(x) != reflect.TypeOf(y) { + return false + } + switch x := x.(type) { + case *ast.Ident: + log.Fatalf("unexpected Ident: %s", astString(tr.fset, x)) + + case *ast.BasicLit: + y := y.(*ast.BasicLit) + xval := exact.MakeFromLiteral(x.Value, x.Kind) + yval := exact.MakeFromLiteral(y.Value, y.Kind) + return exact.Compare(xval, token.EQL, yval) + + case *ast.FuncLit: + // func literals (and thus statement syntax) never match. + return false + + case *ast.CompositeLit: + y := y.(*ast.CompositeLit) + return (x.Type == nil) == (y.Type == nil) && + (x.Type == nil || tr.matchType(x.Type, y.Type)) && + tr.matchExprs(x.Elts, y.Elts) + + case *ast.SelectorExpr: + y := y.(*ast.SelectorExpr) + return tr.matchExpr(x.X, y.X) && + tr.info.Selections[x].Obj() == tr.info.Selections[y].Obj() + + case *ast.IndexExpr: + y := y.(*ast.IndexExpr) + return tr.matchExpr(x.X, y.X) && + tr.matchExpr(x.Index, y.Index) + + case *ast.SliceExpr: + y := y.(*ast.SliceExpr) + return tr.matchExpr(x.X, y.X) && + tr.matchExpr(x.Low, y.Low) && + tr.matchExpr(x.High, y.High) && + tr.matchExpr(x.Max, y.Max) && + x.Slice3 == y.Slice3 + + case *ast.TypeAssertExpr: + y := y.(*ast.TypeAssertExpr) + return tr.matchExpr(x.X, y.X) && + tr.matchType(x.Type, y.Type) + + case *ast.CallExpr: + y := y.(*ast.CallExpr) + match := tr.matchExpr // function call + if tr.info.IsType(x.Fun) { + match = tr.matchType // type conversion + } + return x.Ellipsis.IsValid() == y.Ellipsis.IsValid() && + match(x.Fun, y.Fun) && + tr.matchExprs(x.Args, y.Args) + + case *ast.StarExpr: + y := y.(*ast.StarExpr) + return tr.matchExpr(x.X, y.X) + + case *ast.UnaryExpr: + y := y.(*ast.UnaryExpr) + return x.Op == y.Op && + tr.matchExpr(x.X, y.X) + + case *ast.BinaryExpr: + y := y.(*ast.BinaryExpr) + return x.Op == y.Op && + tr.matchExpr(x.X, y.X) && + tr.matchExpr(x.Y, y.Y) + + case *ast.KeyValueExpr: + y := y.(*ast.KeyValueExpr) + return tr.matchExpr(x.Key, y.Key) && + tr.matchExpr(x.Value, y.Value) + } + + panic(fmt.Sprintf("unhandled AST node type: %T", x)) +} + +func (tr *Transformer) matchExprs(xx, yy []ast.Expr) bool { + if len(xx) != len(yy) { + return false + } + for i := range xx { + if !tr.matchExpr(xx[i], yy[i]) { + return false + } + } + return true +} + +// matchType reports whether the two type ASTs denote identical types. +func (tr *Transformer) matchType(x, y ast.Expr) bool { + tx := tr.info.Types[x].Type + ty := tr.info.Types[y].Type + return types.Identical(tx, ty) +} + +func (tr *Transformer) matchWildcard(xobj *types.Var, y ast.Expr) bool { + name := xobj.Name() + + if tr.verbose { + fmt.Fprintf(os.Stderr, "%s: wildcard %s -> %s?: ", + tr.fset.Position(y.Pos()), name, astString(tr.fset, y)) + } + + // Check that y is assignable to the declared type of the param. + if yt := tr.info.TypeOf(y); !types.AssignableTo(yt, xobj.Type()) { + if tr.verbose { + fmt.Fprintf(os.Stderr, "%s not assignable to %s\n", yt, xobj.Type()) + } + return false + } + + // A wildcard matches any expression. + // If it appears multiple times in the pattern, it must match + // the same expression each time. + if old, ok := tr.env[name]; ok { + // found existing binding + tr.allowWildcards = false + r := tr.matchExpr(old, y) + if tr.verbose { + fmt.Fprintf(os.Stderr, "%t secondary match, primary was %s\n", + r, astString(tr.fset, old)) + } + tr.allowWildcards = true + return r + } + + if tr.verbose { + fmt.Fprintf(os.Stderr, "primary match\n") + } + + tr.env[name] = y // record binding + return true +} + +// -- utilities -------------------------------------------------------- + +// unparen returns e with any enclosing parentheses stripped. +// TODO(adonovan): move to astutil package. +func unparen(e ast.Expr) ast.Expr { + for { + p, ok := e.(*ast.ParenExpr) + if !ok { + break + } + e = p.X + } + return e +} + +// isRef returns the object referred to by this (possibly qualified) +// identifier, or nil if the node is not a referring identifier. +func isRef(n ast.Node, info *loader.PackageInfo) types.Object { + switch n := n.(type) { + case *ast.Ident: + return info.Uses[n] + + case *ast.SelectorExpr: + sel := info.Selections[n] + if sel.Kind() == types.PackageObj { + return sel.Obj() + } + } + return nil +} diff --git a/refactor/eg/rewrite.go b/refactor/eg/rewrite.go new file mode 100644 index 00000000..92921c43 --- /dev/null +++ b/refactor/eg/rewrite.go @@ -0,0 +1,347 @@ +package eg + +// This file defines the AST rewriting pass. +// Most of it was plundered directly from +// $GOROOT/src/cmd/gofmt/rewrite.go (after convergent evolution). + +import ( + "fmt" + "go/ast" + "go/token" + "os" + "reflect" + "sort" + "strconv" + "strings" + + "code.google.com/p/go.tools/astutil" + "code.google.com/p/go.tools/go/types" +) + +// Transform applies the transformation to the specified parsed file, +// whose type information is supplied in info, and returns the number +// of replacements that were made. +// +// It mutates the AST in place (the identity of the root node is +// unchanged), and may add nodes for which no type information is +// available in info. +// +// Derived from rewriteFile in $GOROOT/src/cmd/gofmt/rewrite.go. +// +func (tr *Transformer) Transform(info *types.Info, pkg *types.Package, file *ast.File) int { + if !tr.seenInfos[info] { + tr.seenInfos[info] = true + mergeTypeInfo(&tr.info.Info, info) + } + tr.currentPkg = pkg + tr.nsubsts = 0 + + if tr.verbose { + fmt.Fprintf(os.Stderr, "before: %s\n", astString(tr.fset, tr.before)) + fmt.Fprintf(os.Stderr, "after: %s\n", astString(tr.fset, tr.after)) + } + + var f func(rv reflect.Value) reflect.Value + f = func(rv reflect.Value) reflect.Value { + // don't bother if val is invalid to start with + if !rv.IsValid() { + return reflect.Value{} + } + + rv = apply(f, rv) + + e := rvToExpr(rv) + if e != nil { + savedEnv := tr.env + tr.env = make(map[string]ast.Expr) // inefficient! Use a slice of k/v pairs + + if tr.matchExpr(tr.before, e) { + if tr.verbose { + fmt.Fprintf(os.Stderr, "%s matches %s", + astString(tr.fset, tr.before), astString(tr.fset, e)) + if len(tr.env) > 0 { + fmt.Fprintf(os.Stderr, " with:") + for name, ast := range tr.env { + fmt.Fprintf(os.Stderr, " %s->%s", + name, astString(tr.fset, ast)) + } + } + fmt.Fprintf(os.Stderr, "\n") + } + tr.nsubsts++ + + // Clone the replacement tree, performing parameter substitution. + // We update all positions to n.Pos() to aid comment placement. + rv = tr.subst(tr.env, reflect.ValueOf(tr.after), + reflect.ValueOf(e.Pos())) + } + tr.env = savedEnv + } + + return rv + } + file2 := apply(f, reflect.ValueOf(file)).Interface().(*ast.File) + + // By construction, the root node is unchanged. + if file != file2 { + panic("BUG") + } + + // Add any necessary imports. + // TODO(adonovan): remove no-longer needed imports too. + if tr.nsubsts > 0 { + pkgs := make(map[string]*types.Package) + for obj := range tr.importedObjs { + pkgs[obj.Pkg().Path()] = obj.Pkg() + } + + for _, imp := range file.Imports { + path, _ := strconv.Unquote(imp.Path.Value) + delete(pkgs, path) + } + delete(pkgs, pkg.Path()) // don't import self + + // NB: AddImport may completely replace the AST! + // It thus renders info and tr.info no longer relevant to file. + var paths []string + for path := range pkgs { + paths = append(paths, path) + } + sort.Strings(paths) + for _, path := range paths { + astutil.AddImport(tr.fset, file, path) + } + } + + tr.currentPkg = nil + + return tr.nsubsts +} + +// setValue is a wrapper for x.SetValue(y); it protects +// the caller from panics if x cannot be changed to y. +func setValue(x, y reflect.Value) { + // don't bother if y is invalid to start with + if !y.IsValid() { + return + } + defer func() { + if x := recover(); x != nil { + if s, ok := x.(string); ok && + (strings.Contains(s, "type mismatch") || strings.Contains(s, "not assignable")) { + // x cannot be set to y - ignore this rewrite + return + } + panic(x) + } + }() + x.Set(y) +} + +// Values/types for special cases. +var ( + objectPtrNil = reflect.ValueOf((*ast.Object)(nil)) + scopePtrNil = reflect.ValueOf((*ast.Scope)(nil)) + + identType = reflect.TypeOf((*ast.Ident)(nil)) + selectorExprType = reflect.TypeOf((*ast.SelectorExpr)(nil)) + objectPtrType = reflect.TypeOf((*ast.Object)(nil)) + positionType = reflect.TypeOf(token.NoPos) + callExprType = reflect.TypeOf((*ast.CallExpr)(nil)) + scopePtrType = reflect.TypeOf((*ast.Scope)(nil)) +) + +// apply replaces each AST field x in val with f(x), returning val. +// To avoid extra conversions, f operates on the reflect.Value form. +func apply(f func(reflect.Value) reflect.Value, val reflect.Value) reflect.Value { + if !val.IsValid() { + return reflect.Value{} + } + + // *ast.Objects introduce cycles and are likely incorrect after + // rewrite; don't follow them but replace with nil instead + if val.Type() == objectPtrType { + return objectPtrNil + } + + // similarly for scopes: they are likely incorrect after a rewrite; + // replace them with nil + if val.Type() == scopePtrType { + return scopePtrNil + } + + switch v := reflect.Indirect(val); v.Kind() { + case reflect.Slice: + for i := 0; i < v.Len(); i++ { + e := v.Index(i) + setValue(e, f(e)) + } + case reflect.Struct: + for i := 0; i < v.NumField(); i++ { + e := v.Field(i) + setValue(e, f(e)) + } + case reflect.Interface: + e := v.Elem() + setValue(v, f(e)) + } + return val +} + +// subst returns a copy of (replacement) pattern with values from env +// substituted in place of wildcards and pos used as the position of +// tokens from the pattern. if env == nil, subst returns a copy of +// pattern and doesn't change the line number information. +func (tr *Transformer) subst(env map[string]ast.Expr, pattern, pos reflect.Value) reflect.Value { + if !pattern.IsValid() { + return reflect.Value{} + } + + // *ast.Objects introduce cycles and are likely incorrect after + // rewrite; don't follow them but replace with nil instead + if pattern.Type() == objectPtrType { + return objectPtrNil + } + + // similarly for scopes: they are likely incorrect after a rewrite; + // replace them with nil + if pattern.Type() == scopePtrType { + return scopePtrNil + } + + // Wildcard gets replaced with map value. + if env != nil && pattern.Type() == identType { + id := pattern.Interface().(*ast.Ident) + if old, ok := env[id.Name]; ok { + return tr.subst(nil, reflect.ValueOf(old), reflect.Value{}) + } + } + + // Emit qualified identifiers in the pattern by appropriate + // (possibly qualified) identifier in the input. + // + // The template cannot contain dot imports, so all identifiers + // for imported objects are explicitly qualified. + // + // We assume (unsoundly) that there are no dot or named + // imports in the input code, nor are any imported package + // names shadowed, so the usual normal qualified identifier + // syntax may be used. + // TODO(adonovan): fix: avoid this assumption. + // + // A refactoring may be applied to a package referenced by the + // template. Objects belonging to the current package are + // denoted by unqualified identifiers. + // + if tr.importedObjs != nil && pattern.Type() == selectorExprType { + obj := isRef(pattern.Interface().(*ast.SelectorExpr), &tr.info) + if obj != nil { + if sel, ok := tr.importedObjs[obj]; ok { + var id ast.Expr + if obj.Pkg() == tr.currentPkg { + id = sel.Sel // unqualified + } else { + id = sel // pkg-qualified + } + + // Return a clone of id. + saved := tr.importedObjs + tr.importedObjs = nil // break cycle + r := tr.subst(nil, reflect.ValueOf(id), pos) + tr.importedObjs = saved + return r + } + } + } + + if pos.IsValid() && pattern.Type() == positionType { + // use new position only if old position was valid in the first place + if old := pattern.Interface().(token.Pos); !old.IsValid() { + return pattern + } + return pos + } + + // Otherwise copy. + switch p := pattern; p.Kind() { + case reflect.Slice: + v := reflect.MakeSlice(p.Type(), p.Len(), p.Len()) + for i := 0; i < p.Len(); i++ { + v.Index(i).Set(tr.subst(env, p.Index(i), pos)) + } + return v + + case reflect.Struct: + v := reflect.New(p.Type()).Elem() + for i := 0; i < p.NumField(); i++ { + v.Field(i).Set(tr.subst(env, p.Field(i), pos)) + } + return v + + case reflect.Ptr: + v := reflect.New(p.Type()).Elem() + if elem := p.Elem(); elem.IsValid() { + v.Set(tr.subst(env, elem, pos).Addr()) + } + + // Duplicate type information for duplicated ast.Expr. + // All ast.Node implementations are *structs, + // so this case catches them all. + if e := rvToExpr(v); e != nil { + updateTypeInfo(&tr.info.Info, e, p.Interface().(ast.Expr)) + } + return v + + case reflect.Interface: + v := reflect.New(p.Type()).Elem() + if elem := p.Elem(); elem.IsValid() { + v.Set(tr.subst(env, elem, pos)) + } + return v + } + + return pattern +} + +// -- utilitiies ------------------------------------------------------- + +func rvToExpr(rv reflect.Value) ast.Expr { + if rv.CanInterface() { + if e, ok := rv.Interface().(ast.Expr); ok { + return e + } + } + return nil +} + +// updateTypeInfo duplicates type information for the existing AST old +// so that it also applies to duplicated AST new. +func updateTypeInfo(info *types.Info, new, old ast.Expr) { + switch new := new.(type) { + case *ast.Ident: + orig := old.(*ast.Ident) + if obj, ok := info.Defs[orig]; ok { + info.Defs[new] = obj + } + if obj, ok := info.Uses[orig]; ok { + info.Uses[new] = obj + } + + case *ast.SelectorExpr: + orig := old.(*ast.SelectorExpr) + if sel, ok := info.Selections[orig]; ok { + info.Selections[new] = sel + } + } + + if tv, ok := info.Types[old]; ok { + info.Types[new] = tv + } +} + +func F() {} +func G() {} + +func init() { + F() +} diff --git a/refactor/eg/testdata/A.template b/refactor/eg/testdata/A.template new file mode 100644 index 00000000..f6119618 --- /dev/null +++ b/refactor/eg/testdata/A.template @@ -0,0 +1,13 @@ +// +build ignore + +package template + +// Basic test of type-aware expression refactoring. + +import ( + "errors" + "fmt" +) + +func before(s string) error { return fmt.Errorf("%s", s) } +func after(s string) error { return errors.New(s) } diff --git a/refactor/eg/testdata/A1.go b/refactor/eg/testdata/A1.go new file mode 100644 index 00000000..9e65eb36 --- /dev/null +++ b/refactor/eg/testdata/A1.go @@ -0,0 +1,51 @@ +// +build ignore + +package A1 + +import ( + . "fmt" + myfmt "fmt" + "os" + "strings" +) + +func example(n int) { + x := "foo" + strings.Repeat("\t", n) + // Match, despite named import. + myfmt.Errorf("%s", x) + + // Match, despite dot import. + Errorf("%s", x) + + // Match: multiple matches in same function are possible. + myfmt.Errorf("%s", x) + + // No match: wildcarded operand has the wrong type. + myfmt.Errorf("%s", 3) + + // No match: function operand doesn't match. + myfmt.Printf("%s", x) + + // No match again, dot import. + Printf("%s", x) + + // Match. + myfmt.Fprint(os.Stderr, myfmt.Errorf("%s", x+"foo")) + + // No match: though this literally matches the template, + // fmt doesn't resolve to a package here. + var fmt struct{ Errorf func(string, string) } + fmt.Errorf("%s", x) + + // Recursive matching: + + // Match: both matches are well-typed, so both succeed. + myfmt.Errorf("%s", myfmt.Errorf("%s", x+"foo").Error()) + + // Outer match succeeds, inner doesn't: 3 has wrong type. + myfmt.Errorf("%s", myfmt.Errorf("%s", 3).Error()) + + // Inner match succeeds, outer doesn't: the inner replacement + // has the wrong type (error not string). + myfmt.Errorf("%s", myfmt.Errorf("%s", x+"foo")) +} diff --git a/refactor/eg/testdata/A1.golden b/refactor/eg/testdata/A1.golden new file mode 100644 index 00000000..4f7ba828 --- /dev/null +++ b/refactor/eg/testdata/A1.golden @@ -0,0 +1,52 @@ +// +build ignore + +package A1 + +import ( + . "fmt" + "errors" + myfmt "fmt" + "os" + "strings" +) + +func example(n int) { + x := "foo" + strings.Repeat("\t", n) + // Match, despite named import. + errors.New(x) + + // Match, despite dot import. + errors.New(x) + + // Match: multiple matches in same function are possible. + errors.New(x) + + // No match: wildcarded operand has the wrong type. + myfmt.Errorf("%s", 3) + + // No match: function operand doesn't match. + myfmt.Printf("%s", x) + + // No match again, dot import. + Printf("%s", x) + + // Match. + myfmt.Fprint(os.Stderr, errors.New(x+"foo")) + + // No match: though this literally matches the template, + // fmt doesn't resolve to a package here. + var fmt struct{ Errorf func(string, string) } + fmt.Errorf("%s", x) + + // Recursive matching: + + // Match: both matches are well-typed, so both succeed. + errors.New(errors.New(x + "foo").Error()) + + // Outer match succeeds, inner doesn't: 3 has wrong type. + errors.New(myfmt.Errorf("%s", 3).Error()) + + // Inner match succeeds, outer doesn't: the inner replacement + // has the wrong type (error not string). + myfmt.Errorf("%s", errors.New(x+"foo")) +} diff --git a/refactor/eg/testdata/A2.go b/refactor/eg/testdata/A2.go new file mode 100644 index 00000000..3ae29ad7 --- /dev/null +++ b/refactor/eg/testdata/A2.go @@ -0,0 +1,12 @@ +// +build ignore + +package A2 + +// This refactoring causes addition of "errors" import. +// TODO(adonovan): fix: it should also remove "fmt". + +import myfmt "fmt" + +func example(n int) { + myfmt.Errorf("%s", "") +} diff --git a/refactor/eg/testdata/A2.golden b/refactor/eg/testdata/A2.golden new file mode 100644 index 00000000..5c2384b7 --- /dev/null +++ b/refactor/eg/testdata/A2.golden @@ -0,0 +1,15 @@ +// +build ignore + +package A2 + +// This refactoring causes addition of "errors" import. +// TODO(adonovan): fix: it should also remove "fmt". + +import ( + myfmt "fmt" + "errors" +) + +func example(n int) { + errors.New("") +} diff --git a/refactor/eg/testdata/B.template b/refactor/eg/testdata/B.template new file mode 100644 index 00000000..c16627bd --- /dev/null +++ b/refactor/eg/testdata/B.template @@ -0,0 +1,9 @@ +package template + +// Basic test of expression refactoring. +// (Types are not important in this case; it could be done with gofmt -r.) + +import "time" + +func before(t time.Time) time.Duration { return time.Now().Sub(t) } +func after(t time.Time) time.Duration { return time.Since(t) } diff --git a/refactor/eg/testdata/B1.go b/refactor/eg/testdata/B1.go new file mode 100644 index 00000000..8b525463 --- /dev/null +++ b/refactor/eg/testdata/B1.go @@ -0,0 +1,17 @@ +// +build ignore + +package B1 + +import "time" + +var startup = time.Now() + +func example() time.Duration { + before := time.Now() + time.Sleep(1) + return time.Now().Sub(before) +} + +func msSinceStartup() int64 { + return int64(time.Now().Sub(startup) / time.Millisecond) +} diff --git a/refactor/eg/testdata/B1.golden b/refactor/eg/testdata/B1.golden new file mode 100644 index 00000000..4d4da218 --- /dev/null +++ b/refactor/eg/testdata/B1.golden @@ -0,0 +1,17 @@ +// +build ignore + +package B1 + +import "time" + +var startup = time.Now() + +func example() time.Duration { + before := time.Now() + time.Sleep(1) + return time.Since(before) +} + +func msSinceStartup() int64 { + return int64(time.Since(startup) / time.Millisecond) +} diff --git a/refactor/eg/testdata/C.template b/refactor/eg/testdata/C.template new file mode 100644 index 00000000..f6f94d4a --- /dev/null +++ b/refactor/eg/testdata/C.template @@ -0,0 +1,10 @@ +package template + +// Test of repeated use of wildcard in pattern. + +// NB: multiple patterns would be required to handle variants such as +// s[:len(s)], s[x:len(s)], etc, since a wildcard can't match nothing at all. +// TODO(adonovan): support multiple templates in a single pass. + +func before(s string) string { return s[:len(s)] } +func after(s string) string { return s } diff --git a/refactor/eg/testdata/C1.go b/refactor/eg/testdata/C1.go new file mode 100644 index 00000000..523b3885 --- /dev/null +++ b/refactor/eg/testdata/C1.go @@ -0,0 +1,22 @@ +// +build ignore + +package C1 + +import "strings" + +func example() { + x := "foo" + println(x[:len(x)]) + + // Match, but the transformation is not sound w.r.t. possible side effects. + println(strings.Repeat("*", 3)[:len(strings.Repeat("*", 3))]) + + // No match, since second use of wildcard doesn't match first. + println(strings.Repeat("*", 3)[:len(strings.Repeat("*", 2))]) + + // Recursive match demonstrating bottom-up rewrite: + // only after the inner replacement occurs does the outer syntax match. + println((x[:len(x)])[:len(x[:len(x)])]) + // -> (x[:len(x)]) + // -> x +} diff --git a/refactor/eg/testdata/C1.golden b/refactor/eg/testdata/C1.golden new file mode 100644 index 00000000..ae7759d7 --- /dev/null +++ b/refactor/eg/testdata/C1.golden @@ -0,0 +1,22 @@ +// +build ignore + +package C1 + +import "strings" + +func example() { + x := "foo" + println(x) + + // Match, but the transformation is not sound w.r.t. possible side effects. + println(strings.Repeat("*", 3)) + + // No match, since second use of wildcard doesn't match first. + println(strings.Repeat("*", 3)[:len(strings.Repeat("*", 2))]) + + // Recursive match demonstrating bottom-up rewrite: + // only after the inner replacement occurs does the outer syntax match. + println(x) + // -> (x[:len(x)]) + // -> x +} diff --git a/refactor/eg/testdata/D.template b/refactor/eg/testdata/D.template new file mode 100644 index 00000000..6d3b6feb --- /dev/null +++ b/refactor/eg/testdata/D.template @@ -0,0 +1,8 @@ +package template + +import "fmt" + +// Test of semantic (not syntactic) matching of basic literals. + +func before() (int, error) { return fmt.Println(123, "a") } +func after() (int, error) { return fmt.Println(456, "!") } diff --git a/refactor/eg/testdata/D1.go b/refactor/eg/testdata/D1.go new file mode 100644 index 00000000..ae0a8060 --- /dev/null +++ b/refactor/eg/testdata/D1.go @@ -0,0 +1,12 @@ +// +build ignore + +package D1 + +import "fmt" + +func example() { + fmt.Println(123, "a") // match + fmt.Println(0x7b, `a`) // match + fmt.Println(0173, "\x61") // match + fmt.Println(100+20+3, "a"+"") // no match: constant expressions, but not basic literals +} diff --git a/refactor/eg/testdata/D1.golden b/refactor/eg/testdata/D1.golden new file mode 100644 index 00000000..3f2dc593 --- /dev/null +++ b/refactor/eg/testdata/D1.golden @@ -0,0 +1,12 @@ +// +build ignore + +package D1 + +import "fmt" + +func example() { + fmt.Println(456, "!") // match + fmt.Println(456, "!") // match + fmt.Println(456, "!") // match + fmt.Println(100+20+3, "a"+"") // no match: constant expressions, but not basic literals +} diff --git a/refactor/eg/testdata/E.template b/refactor/eg/testdata/E.template new file mode 100644 index 00000000..4bbbd113 --- /dev/null +++ b/refactor/eg/testdata/E.template @@ -0,0 +1,12 @@ +package template + +import ( + "fmt" + "log" + "os" +) + +// Replace call to void function by call to non-void function. + +func before(x interface{}) { log.Fatal(x) } +func after(x interface{}) { fmt.Fprintf(os.Stderr, "warning: %v", x) } diff --git a/refactor/eg/testdata/E1.go b/refactor/eg/testdata/E1.go new file mode 100644 index 00000000..3ea1793f --- /dev/null +++ b/refactor/eg/testdata/E1.go @@ -0,0 +1,9 @@ +// +build ignore + +package E1 + +import "log" + +func example() { + log.Fatal("oops") // match +} diff --git a/refactor/eg/testdata/E1.golden b/refactor/eg/testdata/E1.golden new file mode 100644 index 00000000..a0adfc8b --- /dev/null +++ b/refactor/eg/testdata/E1.golden @@ -0,0 +1,13 @@ +// +build ignore + +package E1 + +import ( + "log" + "os" + "fmt" +) + +func example() { + fmt.Fprintf(os.Stderr, "warning: %v", "oops") // match +} diff --git a/refactor/eg/testdata/bad_type.template b/refactor/eg/testdata/bad_type.template new file mode 100644 index 00000000..6d53d7e5 --- /dev/null +++ b/refactor/eg/testdata/bad_type.template @@ -0,0 +1,8 @@ +package template + +// Test in which replacement has a different type. + +const shouldFail = "int is not a safe replacement for string" + +func before() interface{} { return "three" } +func after() interface{} { return 3 } diff --git a/refactor/eg/testdata/expr_type_mismatch.template b/refactor/eg/testdata/expr_type_mismatch.template new file mode 100644 index 00000000..2c5c3f0d --- /dev/null +++ b/refactor/eg/testdata/expr_type_mismatch.template @@ -0,0 +1,15 @@ +package template + +import ( + "crypto/x509" + "fmt" +) + +// This test demonstrates a false negative: according to the language +// rules this replacement should be ok, but types.Assignable doesn't work +// in the expected way (elementwise assignability) for tuples. +// Perhaps that's even a type-checker bug? +const shouldFail = "(n int, err error) is not a safe replacement for (key interface{}, err error)" + +func before() (interface{}, error) { return x509.ParsePKCS8PrivateKey(nil) } +func after() (interface{}, error) { return fmt.Print() } diff --git a/refactor/eg/testdata/no_after_return.template b/refactor/eg/testdata/no_after_return.template new file mode 100644 index 00000000..536b01e6 --- /dev/null +++ b/refactor/eg/testdata/no_after_return.template @@ -0,0 +1,6 @@ +package template + +const shouldFail = "after: must contain a single statement" + +func before() int { return 0 } +func after() int { println(); return 0 } diff --git a/refactor/eg/testdata/no_before.template b/refactor/eg/testdata/no_before.template new file mode 100644 index 00000000..9205e667 --- /dev/null +++ b/refactor/eg/testdata/no_before.template @@ -0,0 +1,5 @@ +package template + +const shouldFail = "no 'before' func found in template" + +func Before() {} diff --git a/refactor/eg/testdata/type_mismatch.template b/refactor/eg/testdata/type_mismatch.template new file mode 100644 index 00000000..787c9a7a --- /dev/null +++ b/refactor/eg/testdata/type_mismatch.template @@ -0,0 +1,6 @@ +package template + +const shouldFail = "different signatures" + +func before() int { return 0 } +func after() string { return "" }