345 lines
		
	
	
		
			9.0 KiB
		
	
	
	
		
			Go
		
	
	
	
			
		
		
	
	
			345 lines
		
	
	
		
			9.0 KiB
		
	
	
	
		
			Go
		
	
	
	
| // Copyright 2014 The Go Authors. All rights reserved.
 | |
| // Use of this source code is governed by a BSD-style
 | |
| // license that can be found in the LICENSE file.
 | |
| 
 | |
| package eg
 | |
| 
 | |
| // This file defines the AST rewriting pass.
 | |
| // Most of it was plundered directly from
 | |
| // $GOROOT/src/cmd/gofmt/rewrite.go (after convergent evolution).
 | |
| 
 | |
| import (
 | |
| 	"fmt"
 | |
| 	"go/ast"
 | |
| 	"go/token"
 | |
| 	"go/types"
 | |
| 	"os"
 | |
| 	"reflect"
 | |
| 	"sort"
 | |
| 	"strconv"
 | |
| 	"strings"
 | |
| 
 | |
| 	"golang.org/x/tools/go/ast/astutil"
 | |
| )
 | |
| 
 | |
| // Transform applies the transformation to the specified parsed file,
 | |
| // whose type information is supplied in info, and returns the number
 | |
| // of replacements that were made.
 | |
| //
 | |
| // It mutates the AST in place (the identity of the root node is
 | |
| // unchanged), and may add nodes for which no type information is
 | |
| // available in info.
 | |
| //
 | |
| // Derived from rewriteFile in $GOROOT/src/cmd/gofmt/rewrite.go.
 | |
| //
 | |
| func (tr *Transformer) Transform(info *types.Info, pkg *types.Package, file *ast.File) int {
 | |
| 	if !tr.seenInfos[info] {
 | |
| 		tr.seenInfos[info] = true
 | |
| 		mergeTypeInfo(tr.info, info)
 | |
| 	}
 | |
| 	tr.currentPkg = pkg
 | |
| 	tr.nsubsts = 0
 | |
| 
 | |
| 	if tr.verbose {
 | |
| 		fmt.Fprintf(os.Stderr, "before: %s\n", astString(tr.fset, tr.before))
 | |
| 		fmt.Fprintf(os.Stderr, "after: %s\n", astString(tr.fset, tr.after))
 | |
| 	}
 | |
| 
 | |
| 	var f func(rv reflect.Value) reflect.Value
 | |
| 	f = func(rv reflect.Value) reflect.Value {
 | |
| 		// don't bother if val is invalid to start with
 | |
| 		if !rv.IsValid() {
 | |
| 			return reflect.Value{}
 | |
| 		}
 | |
| 
 | |
| 		rv = apply(f, rv)
 | |
| 
 | |
| 		e := rvToExpr(rv)
 | |
| 		if e != nil {
 | |
| 			savedEnv := tr.env
 | |
| 			tr.env = make(map[string]ast.Expr) // inefficient!  Use a slice of k/v pairs
 | |
| 
 | |
| 			if tr.matchExpr(tr.before, e) {
 | |
| 				if tr.verbose {
 | |
| 					fmt.Fprintf(os.Stderr, "%s matches %s",
 | |
| 						astString(tr.fset, tr.before), astString(tr.fset, e))
 | |
| 					if len(tr.env) > 0 {
 | |
| 						fmt.Fprintf(os.Stderr, " with:")
 | |
| 						for name, ast := range tr.env {
 | |
| 							fmt.Fprintf(os.Stderr, " %s->%s",
 | |
| 								name, astString(tr.fset, ast))
 | |
| 						}
 | |
| 					}
 | |
| 					fmt.Fprintf(os.Stderr, "\n")
 | |
| 				}
 | |
| 				tr.nsubsts++
 | |
| 
 | |
| 				// Clone the replacement tree, performing parameter substitution.
 | |
| 				// We update all positions to n.Pos() to aid comment placement.
 | |
| 				rv = tr.subst(tr.env, reflect.ValueOf(tr.after),
 | |
| 					reflect.ValueOf(e.Pos()))
 | |
| 			}
 | |
| 			tr.env = savedEnv
 | |
| 		}
 | |
| 
 | |
| 		return rv
 | |
| 	}
 | |
| 	file2 := apply(f, reflect.ValueOf(file)).Interface().(*ast.File)
 | |
| 
 | |
| 	// By construction, the root node is unchanged.
 | |
| 	if file != file2 {
 | |
| 		panic("BUG")
 | |
| 	}
 | |
| 
 | |
| 	// Add any necessary imports.
 | |
| 	// TODO(adonovan): remove no-longer needed imports too.
 | |
| 	if tr.nsubsts > 0 {
 | |
| 		pkgs := make(map[string]*types.Package)
 | |
| 		for obj := range tr.importedObjs {
 | |
| 			pkgs[obj.Pkg().Path()] = obj.Pkg()
 | |
| 		}
 | |
| 
 | |
| 		for _, imp := range file.Imports {
 | |
| 			path, _ := strconv.Unquote(imp.Path.Value)
 | |
| 			delete(pkgs, path)
 | |
| 		}
 | |
| 		delete(pkgs, pkg.Path()) // don't import self
 | |
| 
 | |
| 		// NB: AddImport may completely replace the AST!
 | |
| 		// It thus renders info and tr.info no longer relevant to file.
 | |
| 		var paths []string
 | |
| 		for path := range pkgs {
 | |
| 			paths = append(paths, path)
 | |
| 		}
 | |
| 		sort.Strings(paths)
 | |
| 		for _, path := range paths {
 | |
| 			astutil.AddImport(tr.fset, file, path)
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	tr.currentPkg = nil
 | |
| 
 | |
| 	return tr.nsubsts
 | |
| }
 | |
| 
 | |
// setValue assigns y to x using x.Set(y), shielding the caller from
// the panic that reflect raises when y's type cannot be stored in x;
// in that case the assignment is silently skipped. An invalid y is a
// no-op. Any other panic is propagated unchanged.
func setValue(x, y reflect.Value) {
	if !y.IsValid() {
		// nothing to assign
		return
	}
	defer func() {
		r := recover()
		if r == nil {
			return
		}
		msg, isString := r.(string)
		incompatible := isString &&
			(strings.Contains(msg, "type mismatch") || strings.Contains(msg, "not assignable"))
		if !incompatible {
			// not an assignability failure: re-raise
			panic(r)
		}
		// x cannot hold y - drop this rewrite
	}()
	x.Set(y)
}
 | |
| 
 | |
// Reflected values and types used to recognise special cases while
// walking and copying ASTs.
var (
	// Typed nil pointers substituted for *ast.Object and *ast.Scope fields.
	objectPtrNil = reflect.ValueOf((*ast.Object)(nil))
	scopePtrNil  = reflect.ValueOf((*ast.Scope)(nil))

	// Types of the pointers above.
	objectPtrType = objectPtrNil.Type()
	scopePtrType  = scopePtrNil.Type()

	// Types of the AST nodes and positions that get special treatment.
	identType        = reflect.TypeOf((*ast.Ident)(nil))
	selectorExprType = reflect.TypeOf((*ast.SelectorExpr)(nil))
	callExprType     = reflect.TypeOf((*ast.CallExpr)(nil))
	positionType     = reflect.TypeOf(token.NoPos)
)
 | |
| 
 | |
| // apply replaces each AST field x in val with f(x), returning val.
 | |
| // To avoid extra conversions, f operates on the reflect.Value form.
 | |
| func apply(f func(reflect.Value) reflect.Value, val reflect.Value) reflect.Value {
 | |
| 	if !val.IsValid() {
 | |
| 		return reflect.Value{}
 | |
| 	}
 | |
| 
 | |
| 	// *ast.Objects introduce cycles and are likely incorrect after
 | |
| 	// rewrite; don't follow them but replace with nil instead
 | |
| 	if val.Type() == objectPtrType {
 | |
| 		return objectPtrNil
 | |
| 	}
 | |
| 
 | |
| 	// similarly for scopes: they are likely incorrect after a rewrite;
 | |
| 	// replace them with nil
 | |
| 	if val.Type() == scopePtrType {
 | |
| 		return scopePtrNil
 | |
| 	}
 | |
| 
 | |
| 	switch v := reflect.Indirect(val); v.Kind() {
 | |
| 	case reflect.Slice:
 | |
| 		for i := 0; i < v.Len(); i++ {
 | |
| 			e := v.Index(i)
 | |
| 			setValue(e, f(e))
 | |
| 		}
 | |
| 	case reflect.Struct:
 | |
| 		for i := 0; i < v.NumField(); i++ {
 | |
| 			e := v.Field(i)
 | |
| 			setValue(e, f(e))
 | |
| 		}
 | |
| 	case reflect.Interface:
 | |
| 		e := v.Elem()
 | |
| 		setValue(v, f(e))
 | |
| 	}
 | |
| 	return val
 | |
| }
 | |
| 
 | |
| // subst returns a copy of (replacement) pattern with values from env
 | |
| // substituted in place of wildcards and pos used as the position of
 | |
| // tokens from the pattern.  if env == nil, subst returns a copy of
 | |
| // pattern and doesn't change the line number information.
 | |
| func (tr *Transformer) subst(env map[string]ast.Expr, pattern, pos reflect.Value) reflect.Value {
 | |
| 	if !pattern.IsValid() {
 | |
| 		return reflect.Value{}
 | |
| 	}
 | |
| 
 | |
| 	// *ast.Objects introduce cycles and are likely incorrect after
 | |
| 	// rewrite; don't follow them but replace with nil instead
 | |
| 	if pattern.Type() == objectPtrType {
 | |
| 		return objectPtrNil
 | |
| 	}
 | |
| 
 | |
| 	// similarly for scopes: they are likely incorrect after a rewrite;
 | |
| 	// replace them with nil
 | |
| 	if pattern.Type() == scopePtrType {
 | |
| 		return scopePtrNil
 | |
| 	}
 | |
| 
 | |
| 	// Wildcard gets replaced with map value.
 | |
| 	if env != nil && pattern.Type() == identType {
 | |
| 		id := pattern.Interface().(*ast.Ident)
 | |
| 		if old, ok := env[id.Name]; ok {
 | |
| 			return tr.subst(nil, reflect.ValueOf(old), reflect.Value{})
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	// Emit qualified identifiers in the pattern by appropriate
 | |
| 	// (possibly qualified) identifier in the input.
 | |
| 	//
 | |
| 	// The template cannot contain dot imports, so all identifiers
 | |
| 	// for imported objects are explicitly qualified.
 | |
| 	//
 | |
| 	// We assume (unsoundly) that there are no dot or named
 | |
| 	// imports in the input code, nor are any imported package
 | |
| 	// names shadowed, so the usual normal qualified identifier
 | |
| 	// syntax may be used.
 | |
| 	// TODO(adonovan): fix: avoid this assumption.
 | |
| 	//
 | |
| 	// A refactoring may be applied to a package referenced by the
 | |
| 	// template.  Objects belonging to the current package are
 | |
| 	// denoted by unqualified identifiers.
 | |
| 	//
 | |
| 	if tr.importedObjs != nil && pattern.Type() == selectorExprType {
 | |
| 		obj := isRef(pattern.Interface().(*ast.SelectorExpr), tr.info)
 | |
| 		if obj != nil {
 | |
| 			if sel, ok := tr.importedObjs[obj]; ok {
 | |
| 				var id ast.Expr
 | |
| 				if obj.Pkg() == tr.currentPkg {
 | |
| 					id = sel.Sel // unqualified
 | |
| 				} else {
 | |
| 					id = sel // pkg-qualified
 | |
| 				}
 | |
| 
 | |
| 				// Return a clone of id.
 | |
| 				saved := tr.importedObjs
 | |
| 				tr.importedObjs = nil // break cycle
 | |
| 				r := tr.subst(nil, reflect.ValueOf(id), pos)
 | |
| 				tr.importedObjs = saved
 | |
| 				return r
 | |
| 			}
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	if pos.IsValid() && pattern.Type() == positionType {
 | |
| 		// use new position only if old position was valid in the first place
 | |
| 		if old := pattern.Interface().(token.Pos); !old.IsValid() {
 | |
| 			return pattern
 | |
| 		}
 | |
| 		return pos
 | |
| 	}
 | |
| 
 | |
| 	// Otherwise copy.
 | |
| 	switch p := pattern; p.Kind() {
 | |
| 	case reflect.Slice:
 | |
| 		v := reflect.MakeSlice(p.Type(), p.Len(), p.Len())
 | |
| 		for i := 0; i < p.Len(); i++ {
 | |
| 			v.Index(i).Set(tr.subst(env, p.Index(i), pos))
 | |
| 		}
 | |
| 		return v
 | |
| 
 | |
| 	case reflect.Struct:
 | |
| 		v := reflect.New(p.Type()).Elem()
 | |
| 		for i := 0; i < p.NumField(); i++ {
 | |
| 			v.Field(i).Set(tr.subst(env, p.Field(i), pos))
 | |
| 		}
 | |
| 		return v
 | |
| 
 | |
| 	case reflect.Ptr:
 | |
| 		v := reflect.New(p.Type()).Elem()
 | |
| 		if elem := p.Elem(); elem.IsValid() {
 | |
| 			v.Set(tr.subst(env, elem, pos).Addr())
 | |
| 		}
 | |
| 
 | |
| 		// Duplicate type information for duplicated ast.Expr.
 | |
| 		// All ast.Node implementations are *structs,
 | |
| 		// so this case catches them all.
 | |
| 		if e := rvToExpr(v); e != nil {
 | |
| 			updateTypeInfo(tr.info, e, p.Interface().(ast.Expr))
 | |
| 		}
 | |
| 		return v
 | |
| 
 | |
| 	case reflect.Interface:
 | |
| 		v := reflect.New(p.Type()).Elem()
 | |
| 		if elem := p.Elem(); elem.IsValid() {
 | |
| 			v.Set(tr.subst(env, elem, pos))
 | |
| 		}
 | |
| 		return v
 | |
| 	}
 | |
| 
 | |
| 	return pattern
 | |
| }
 | |
| 
 | |
| // -- utilities -------------------------------------------------------
 | |
| 
 | |
// rvToExpr returns the ast.Expr held by rv, or nil if rv is invalid,
// cannot be converted to an interface value (for example because it
// was obtained through an unexported field), or does not implement
// ast.Expr.
//
// Note that a typed nil pointer of an expression type is returned as
// a non-nil ast.Expr wrapping that nil pointer.
func rvToExpr(rv reflect.Value) ast.Expr {
	// CanInterface panics on the zero Value, so rule that out first.
	if !rv.IsValid() || !rv.CanInterface() {
		return nil
	}
	if e, ok := rv.Interface().(ast.Expr); ok {
		return e
	}
	return nil
}
 | |
| 
 | |
// updateTypeInfo copies whatever info records about the existing
// expression old so that the same facts are also recorded for its
// duplicate new.
//
// For identifiers the Defs and Uses entries are copied, for selector
// expressions the Selections entry, and for every expression the
// Types entry. Entries that old does not have are not created.
// old must have the same dynamic type as new when new is an
// *ast.Ident or *ast.SelectorExpr.
func updateTypeInfo(info *types.Info, new, old ast.Expr) {
	if dupID, isIdent := new.(*ast.Ident); isIdent {
		srcID := old.(*ast.Ident)
		if def, found := info.Defs[srcID]; found {
			info.Defs[dupID] = def
		}
		if use, found := info.Uses[srcID]; found {
			info.Uses[dupID] = use
		}
	} else if dupSel, isSel := new.(*ast.SelectorExpr); isSel {
		srcSel := old.(*ast.SelectorExpr)
		if selection, found := info.Selections[srcSel]; found {
			info.Selections[dupSel] = selection
		}
	}

	if tv, found := info.Types[old]; found {
		info.Types[new] = tv
	}
}
 |