466 lines
		
	
	
		
			13 KiB
		
	
	
	
		
			Go
		
	
	
	
			
		
		
	
	
			466 lines
		
	
	
		
			13 KiB
		
	
	
	
		
			Go
		
	
	
	
// Copyright 2013 The Go Authors. All rights reserved.
 | 
						|
// Use of this source code is governed by a BSD-style
 | 
						|
// license that can be found in the LICENSE file.
 | 
						|
 | 
						|
package oracle
 | 
						|
 | 
						|
// This file defines oracle.Query, the entry point for the oracle tool.
 | 
						|
// The actual executable is defined in cmd/oracle.
 | 
						|
 | 
						|
// TODO(adonovan): new query: show all statements that may update the
 | 
						|
// selected lvalue (local, global, field, etc).
 | 
						|
 | 
						|
import (
 | 
						|
	"bytes"
 | 
						|
	encjson "encoding/json"
 | 
						|
	"errors"
 | 
						|
	"fmt"
 | 
						|
	"go/ast"
 | 
						|
	"go/build"
 | 
						|
	"go/printer"
 | 
						|
	"go/token"
 | 
						|
	"io"
 | 
						|
	"os"
 | 
						|
	"path/filepath"
 | 
						|
	"strconv"
 | 
						|
	"strings"
 | 
						|
	"time"
 | 
						|
 | 
						|
	"code.google.com/p/go.tools/go/types"
 | 
						|
	"code.google.com/p/go.tools/importer"
 | 
						|
	"code.google.com/p/go.tools/oracle/json"
 | 
						|
	"code.google.com/p/go.tools/pointer"
 | 
						|
	"code.google.com/p/go.tools/ssa"
 | 
						|
)
 | 
						|
 | 
						|
type oracle struct {
 | 
						|
	out    io.Writer      // standard output
 | 
						|
	prog   *ssa.Program   // the SSA program [only populated if need&SSA]
 | 
						|
	config pointer.Config // pointer analysis configuration
 | 
						|
 | 
						|
	// need&(Pos|ExactPos):
 | 
						|
	startPos, endPos token.Pos             // source extent of query
 | 
						|
	queryPkgInfo     *importer.PackageInfo // type info for the queried package
 | 
						|
	queryPath        []ast.Node            // AST path from query node to root of ast.File
 | 
						|
 | 
						|
	// need&AllTypeInfo
 | 
						|
	typeInfo map[*types.Package]*importer.PackageInfo // type info for all ASTs in the program
 | 
						|
 | 
						|
	timers map[string]time.Duration // phase timing information
 | 
						|
}
 | 
						|
 | 
						|
// Bit flags describing the analytical requirements of each query mode.
// A mode's requirements are the bitwise OR of the flags it needs.
//
// Typed ASTs for the whole program are always constructed
// transiently; they are retained only for the queried package unless
// AllTypeInfo is set.
const (
	Pos         = 1 << iota // mode needs a source position
	ExactPos                // mode needs an exact AST selection; implies Pos
	AllTypeInfo             // mode needs type info retained for all ASTs in the program
	SSA                     // mode needs ssa.Packages for the whole program
	PTA         = SSA       // mode needs pointer analysis (shares the SSA bit)
)
 | 
						|
 | 
						|
type modeInfo struct {
 | 
						|
	needs int
 | 
						|
	impl  func(*oracle) (queryResult, error)
 | 
						|
}
 | 
						|
 | 
						|
var modes = map[string]modeInfo{
 | 
						|
	"callees":    modeInfo{PTA | ExactPos, callees},
 | 
						|
	"callers":    modeInfo{PTA | Pos, callers},
 | 
						|
	"callgraph":  modeInfo{PTA, callgraph},
 | 
						|
	"callstack":  modeInfo{PTA | Pos, callstack},
 | 
						|
	"describe":   modeInfo{PTA | ExactPos, describe},
 | 
						|
	"freevars":   modeInfo{Pos, freevars},
 | 
						|
	"implements": modeInfo{Pos, implements},
 | 
						|
	"peers":      modeInfo{PTA | Pos, peers},
 | 
						|
	"referrers":  modeInfo{AllTypeInfo | Pos, referrers},
 | 
						|
}
 | 
						|
 | 
						|
// printfFunc is the type of the position-aware printf callback handed
// to queryResult.display. pos selects the location prefix of the
// message and accepts the same kinds of values as (*oracle).fprintf.
type printfFunc func(pos interface{}, format string, args ...interface{})
 | 
						|
 | 
						|
// queryResult is the interface of each query-specific result type.
 | 
						|
type queryResult interface {
 | 
						|
	toJSON(res *json.Result, fset *token.FileSet)
 | 
						|
	display(printf printfFunc)
 | 
						|
}
 | 
						|
 | 
						|
// warning is a diagnostic captured from the pointer analysis: a
// printf-style message attached to a source position. Formatting is
// deferred until the result is printed or serialized.
//
// The field order matters: Query builds values with a positional literal.
type warning struct {
	pos    token.Pos     // position the warning refers to
	format string        // printf-style format string
	args   []interface{} // operands for format
}
 | 
						|
 | 
						|
// A Result encapsulates the result of an oracle.Query.
 | 
						|
//
 | 
						|
// Result instances implement the json.Marshaler interface, i.e. they
 | 
						|
// can be JSON-serialized.
 | 
						|
type Result struct {
 | 
						|
	fset *token.FileSet
 | 
						|
	// fprintf is a closure over the oracle's fileset and start/end position.
 | 
						|
	fprintf  func(w io.Writer, pos interface{}, format string, args ...interface{})
 | 
						|
	q        queryResult // the query-specific result
 | 
						|
	mode     string      // query mode
 | 
						|
	warnings []warning   // pointer analysis warnings
 | 
						|
}
 | 
						|
 | 
						|
func (res *Result) MarshalJSON() ([]byte, error) {
 | 
						|
	resj := &json.Result{Mode: res.mode}
 | 
						|
	res.q.toJSON(resj, res.fset)
 | 
						|
	for _, w := range res.warnings {
 | 
						|
		resj.Warnings = append(resj.Warnings, json.PTAWarning{
 | 
						|
			Pos:     res.fset.Position(w.pos).String(),
 | 
						|
			Message: fmt.Sprintf(w.format, w.args...),
 | 
						|
		})
 | 
						|
	}
 | 
						|
	return encjson.Marshal(resj)
 | 
						|
}
 | 
						|
 | 
						|
// Query runs the oracle.
 | 
						|
// args specify the main package in importer.CreatePackageFromArgs syntax.
 | 
						|
// mode is the query mode ("callers", etc).
 | 
						|
// pos is the selection in parseQueryPos() syntax.
 | 
						|
// ptalog is the (optional) pointer-analysis log file.
 | 
						|
// buildContext is the optional configuration for locating packages.
 | 
						|
//
 | 
						|
func Query(args []string, mode, pos string, ptalog io.Writer, buildContext *build.Context) (*Result, error) {
 | 
						|
	minfo, ok := modes[mode]
 | 
						|
	if !ok {
 | 
						|
		if mode == "" {
 | 
						|
			return nil, errors.New("you must specify a -mode of query to perform")
 | 
						|
		}
 | 
						|
		return nil, fmt.Errorf("invalid mode type: %q", mode)
 | 
						|
	}
 | 
						|
 | 
						|
	imp := importer.New(&importer.Config{Build: buildContext})
 | 
						|
	o := &oracle{
 | 
						|
		prog:   ssa.NewProgram(imp.Fset, 0),
 | 
						|
		timers: make(map[string]time.Duration),
 | 
						|
	}
 | 
						|
	o.config.Log = ptalog
 | 
						|
 | 
						|
	var res Result
 | 
						|
	o.config.Warn = func(pos token.Pos, format string, args ...interface{}) {
 | 
						|
		res.warnings = append(res.warnings, warning{pos, format, args})
 | 
						|
	}
 | 
						|
 | 
						|
	// Phase timing diagnostics.
 | 
						|
	if false {
 | 
						|
		defer func() {
 | 
						|
			fmt.Println()
 | 
						|
			for name, duration := range o.timers {
 | 
						|
				fmt.Printf("# %-30s %s\n", name, duration)
 | 
						|
			}
 | 
						|
		}()
 | 
						|
	}
 | 
						|
 | 
						|
	// Load/parse/type-check program from args.
 | 
						|
	start := time.Now()
 | 
						|
	initialPkgInfos, args, err := imp.LoadInitialPackages(args)
 | 
						|
	if err != nil {
 | 
						|
		return nil, err // I/O or parser error
 | 
						|
	}
 | 
						|
	if len(args) > 0 {
 | 
						|
		return nil, fmt.Errorf("surplus arguments: %q", args)
 | 
						|
	}
 | 
						|
	o.timers["load/parse/type"] = time.Since(start)
 | 
						|
 | 
						|
	// Retain type info for all ASTs in the program.
 | 
						|
	if minfo.needs&AllTypeInfo != 0 {
 | 
						|
		m := make(map[*types.Package]*importer.PackageInfo)
 | 
						|
		for _, p := range imp.AllPackages() {
 | 
						|
			m[p.Pkg] = p
 | 
						|
		}
 | 
						|
		o.typeInfo = m
 | 
						|
	}
 | 
						|
 | 
						|
	// Parse the source query position.
 | 
						|
	if minfo.needs&(Pos|ExactPos) != 0 {
 | 
						|
		var err error
 | 
						|
		o.startPos, o.endPos, err = parseQueryPos(o.prog.Fset, pos)
 | 
						|
		if err != nil {
 | 
						|
			return nil, err
 | 
						|
		}
 | 
						|
 | 
						|
		var exact bool
 | 
						|
		o.queryPkgInfo, o.queryPath, exact = imp.PathEnclosingInterval(o.startPos, o.endPos)
 | 
						|
		if o.queryPath == nil {
 | 
						|
			return nil, o.errorf(false, "no syntax here")
 | 
						|
		}
 | 
						|
		if minfo.needs&ExactPos != 0 && !exact {
 | 
						|
			return nil, o.errorf(o.queryPath[0], "ambiguous selection within %s",
 | 
						|
				importer.NodeDescription(o.queryPath[0]))
 | 
						|
		}
 | 
						|
	}
 | 
						|
 | 
						|
	// Create SSA package for the initial package and its dependencies.
 | 
						|
	if minfo.needs&SSA != 0 {
 | 
						|
		start = time.Now()
 | 
						|
 | 
						|
		// Create SSA packages.
 | 
						|
		if err := o.prog.CreatePackages(imp); err != nil {
 | 
						|
			return nil, o.errorf(false, "%s", err)
 | 
						|
		}
 | 
						|
 | 
						|
		// Initial packages (specified on command line)
 | 
						|
		for _, info := range initialPkgInfos {
 | 
						|
			initialPkg := o.prog.Package(info.Pkg)
 | 
						|
 | 
						|
			// Add package to the pointer analysis scope.
 | 
						|
			if initialPkg.Func("main") == nil {
 | 
						|
				// TODO(adonovan): to simulate 'go test' more faithfully, we
 | 
						|
				// should build a single synthetic testmain package,
 | 
						|
				// not synthetic main functions to many packages.
 | 
						|
				if initialPkg.CreateTestMainFunction() == nil {
 | 
						|
					return nil, o.errorf(false, "analysis scope has no main() entry points")
 | 
						|
				}
 | 
						|
			}
 | 
						|
			o.config.Mains = append(o.config.Mains, initialPkg)
 | 
						|
		}
 | 
						|
 | 
						|
		// Query package.
 | 
						|
		if o.queryPkgInfo != nil {
 | 
						|
			pkg := o.prog.Package(o.queryPkgInfo.Pkg)
 | 
						|
			pkg.SetDebugMode(true)
 | 
						|
			pkg.Build()
 | 
						|
		}
 | 
						|
 | 
						|
		o.timers["SSA-create"] = time.Since(start)
 | 
						|
	}
 | 
						|
 | 
						|
	// SSA is built and we have query{Path,PkgInfo}.
 | 
						|
	// Release the other ASTs and type info to the GC.
 | 
						|
	imp = nil
 | 
						|
 | 
						|
	res.q, err = minfo.impl(o)
 | 
						|
	if err != nil {
 | 
						|
		return nil, err
 | 
						|
	}
 | 
						|
	res.mode = mode
 | 
						|
	res.fset = o.prog.Fset
 | 
						|
	res.fprintf = o.fprintf // captures o.prog, o.{start,end}Pos for later printing
 | 
						|
	return &res, nil
 | 
						|
}
 | 
						|
 | 
						|
// WriteTo writes the oracle query result res to out in a compiler diagnostic format.
 | 
						|
func (res *Result) WriteTo(out io.Writer) {
 | 
						|
	printf := func(pos interface{}, format string, args ...interface{}) {
 | 
						|
		res.fprintf(out, pos, format, args...)
 | 
						|
	}
 | 
						|
	res.q.display(printf)
 | 
						|
 | 
						|
	// Print warnings after the main output.
 | 
						|
	if res.warnings != nil {
 | 
						|
		fmt.Fprintln(out, "\nPointer analysis warnings:")
 | 
						|
		for _, w := range res.warnings {
 | 
						|
			printf(w.pos, "warning: "+w.format, w.args...)
 | 
						|
		}
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
// ---------- Utilities ----------
 | 
						|
 | 
						|
// buildSSA constructs the SSA representation of Go-source function bodies.
 | 
						|
// Not needed in simpler modes, e.g. freevars.
 | 
						|
//
 | 
						|
func buildSSA(o *oracle) {
 | 
						|
	start := time.Now()
 | 
						|
	o.prog.BuildAll()
 | 
						|
	o.timers["SSA-build"] = time.Since(start)
 | 
						|
}
 | 
						|
 | 
						|
// ptrAnalysis runs the pointer analysis and returns the synthetic
 | 
						|
// root of the callgraph.
 | 
						|
//
 | 
						|
func ptrAnalysis(o *oracle) pointer.CallGraphNode {
 | 
						|
	start := time.Now()
 | 
						|
	root := pointer.Analyze(&o.config)
 | 
						|
	o.timers["pointer analysis"] = time.Since(start)
 | 
						|
	return root
 | 
						|
}
 | 
						|
 | 
						|
// parseOctothorpDecimal returns the numeric value if s matches "#%d",
// otherwise -1.
//
// The digits after the '#' are parsed as a base-10 signed 32-bit
// integer, so a value that does not fit in 32 bits is also rejected.
func parseOctothorpDecimal(s string) int {
	if len(s) == 0 || s[0] != '#' {
		return -1
	}
	n, err := strconv.ParseInt(s[1:], 10, 32)
	if err != nil {
		return -1
	}
	return int(n)
}
 | 
						|
 | 
						|
// parseQueryPos parses a string of the form "file:pos" or
 | 
						|
// file:start,end" where pos, start, end match #%d and represent byte
 | 
						|
// offsets, and returns the extent to which it refers.
 | 
						|
//
 | 
						|
// (Numbers without a '#' prefix are reserved for future use,
 | 
						|
// e.g. to indicate line/column positions.)
 | 
						|
//
 | 
						|
func parseQueryPos(fset *token.FileSet, queryPos string) (start, end token.Pos, err error) {
 | 
						|
	if queryPos == "" {
 | 
						|
		err = fmt.Errorf("no source position specified (-pos flag)")
 | 
						|
		return
 | 
						|
	}
 | 
						|
 | 
						|
	colon := strings.LastIndex(queryPos, ":")
 | 
						|
	if colon < 0 {
 | 
						|
		err = fmt.Errorf("invalid source position -pos=%q", queryPos)
 | 
						|
		return
 | 
						|
	}
 | 
						|
	filename, offset := queryPos[:colon], queryPos[colon+1:]
 | 
						|
	startOffset := -1
 | 
						|
	endOffset := -1
 | 
						|
	if hyphen := strings.Index(offset, ","); hyphen < 0 {
 | 
						|
		// e.g. "foo.go:#123"
 | 
						|
		startOffset = parseOctothorpDecimal(offset)
 | 
						|
		endOffset = startOffset
 | 
						|
	} else {
 | 
						|
		// e.g. "foo.go:#123,#456"
 | 
						|
		startOffset = parseOctothorpDecimal(offset[:hyphen])
 | 
						|
		endOffset = parseOctothorpDecimal(offset[hyphen+1:])
 | 
						|
	}
 | 
						|
	if startOffset < 0 || endOffset < 0 {
 | 
						|
		err = fmt.Errorf("invalid -pos offset %q", offset)
 | 
						|
		return
 | 
						|
	}
 | 
						|
 | 
						|
	var file *token.File
 | 
						|
	fset.Iterate(func(f *token.File) bool {
 | 
						|
		if sameFile(filename, f.Name()) {
 | 
						|
			// (f.Name() is absolute)
 | 
						|
			file = f
 | 
						|
			return false // done
 | 
						|
		}
 | 
						|
		return true // continue
 | 
						|
	})
 | 
						|
	if file == nil {
 | 
						|
		err = fmt.Errorf("couldn't find file containing position -pos=%q", queryPos)
 | 
						|
		return
 | 
						|
	}
 | 
						|
 | 
						|
	// Range check [start..end], inclusive of both end-points.
 | 
						|
 | 
						|
	if 0 <= startOffset && startOffset <= file.Size() {
 | 
						|
		start = file.Pos(int(startOffset))
 | 
						|
	} else {
 | 
						|
		err = fmt.Errorf("start position is beyond end of file -pos=%q", queryPos)
 | 
						|
		return
 | 
						|
	}
 | 
						|
 | 
						|
	if 0 <= endOffset && endOffset <= file.Size() {
 | 
						|
		end = file.Pos(int(endOffset))
 | 
						|
	} else {
 | 
						|
		err = fmt.Errorf("end position is beyond end of file -pos=%q", queryPos)
 | 
						|
		return
 | 
						|
	}
 | 
						|
 | 
						|
	return
 | 
						|
}
 | 
						|
 | 
						|
// sameFile returns true if x and y have the same basename and denote
// the same file.
//
// The basename comparison is an optimisation that avoids two stat
// calls for most non-matching pairs. If either path cannot be
// stat'ed, the result is false.
func sameFile(x, y string) bool {
	if filepath.Base(x) != filepath.Base(y) {
		return false
	}
	xInfo, err := os.Stat(x)
	if err != nil {
		return false
	}
	yInfo, err := os.Stat(y)
	if err != nil {
		return false
	}
	return os.SameFile(xInfo, yInfo)
}
 | 
						|
 | 
						|
// unparen returns e with any enclosing parentheses stripped,
// however deeply they are nested.
func unparen(e ast.Expr) ast.Expr {
	paren, isParen := e.(*ast.ParenExpr)
	for isParen {
		e = paren.X
		paren, isParen = e.(*ast.ParenExpr)
	}
	return e
}
 | 
						|
 | 
						|
// deref returns a pointer's element type; otherwise it returns typ.
// The test is made on typ's underlying type, so a named pointer
// type is dereferenced too.
func deref(typ types.Type) types.Type {
	ptr, isPtr := typ.Underlying().(*types.Pointer)
	if !isPtr {
		return typ
	}
	return ptr.Elem()
}
 | 
						|
 | 
						|
// fprintf prints to w a message of the form "location: message\n"
 | 
						|
// where location is derived from pos.
 | 
						|
//
 | 
						|
// pos must be one of:
 | 
						|
//    - a token.Pos, denoting a position
 | 
						|
//    - an ast.Node, denoting an interval
 | 
						|
//    - anything with a Pos() method:
 | 
						|
//         ssa.Member, ssa.Value, ssa.Instruction, types.Object, pointer.Label, etc.
 | 
						|
//    - a bool, meaning the extent [o.startPos, o.endPos) of the user's query.
 | 
						|
//      (the value is ignored)
 | 
						|
//    - nil, meaning no position at all.
 | 
						|
//
 | 
						|
// The output format is is compatible with the 'gnu'
 | 
						|
// compilation-error-regexp in Emacs' compilation mode.
 | 
						|
// TODO(adonovan): support other editors.
 | 
						|
//
 | 
						|
func (o *oracle) fprintf(w io.Writer, pos interface{}, format string, args ...interface{}) {
 | 
						|
	var start, end token.Pos
 | 
						|
	switch pos := pos.(type) {
 | 
						|
	case ast.Node:
 | 
						|
		start = pos.Pos()
 | 
						|
		end = pos.End()
 | 
						|
	case token.Pos:
 | 
						|
		start = pos
 | 
						|
		end = start
 | 
						|
	case interface {
 | 
						|
		Pos() token.Pos
 | 
						|
	}:
 | 
						|
		start = pos.Pos()
 | 
						|
		end = start
 | 
						|
	case bool:
 | 
						|
		start = o.startPos
 | 
						|
		end = o.endPos
 | 
						|
	case nil:
 | 
						|
		// no-op
 | 
						|
	default:
 | 
						|
		panic(fmt.Sprintf("invalid pos: %T", pos))
 | 
						|
	}
 | 
						|
 | 
						|
	if sp := o.prog.Fset.Position(start); start == end {
 | 
						|
		// (prints "-: " for token.NoPos)
 | 
						|
		fmt.Fprintf(w, "%s: ", sp)
 | 
						|
	} else {
 | 
						|
		ep := o.prog.Fset.Position(end)
 | 
						|
		// The -1 below is a concession to Emacs's broken use of
 | 
						|
		// inclusive (not half-open) intervals.
 | 
						|
		// Other editors may not want it.
 | 
						|
		// TODO(adonovan): add an -editor=vim|emacs|acme|auto
 | 
						|
		// flag; auto uses EMACS=t / VIM=... / etc env vars.
 | 
						|
		fmt.Fprintf(w, "%s:%d.%d-%d.%d: ",
 | 
						|
			sp.Filename, sp.Line, sp.Column, ep.Line, ep.Column-1)
 | 
						|
	}
 | 
						|
	fmt.Fprintf(w, format, args...)
 | 
						|
	io.WriteString(w, "\n")
 | 
						|
}
 | 
						|
 | 
						|
// errorf is like fprintf, but returns a formatted error string.
 | 
						|
func (o *oracle) errorf(pos interface{}, format string, args ...interface{}) error {
 | 
						|
	var buf bytes.Buffer
 | 
						|
	o.fprintf(&buf, pos, format, args...)
 | 
						|
	return errors.New(buf.String())
 | 
						|
}
 | 
						|
 | 
						|
// printNode returns the pretty-printed syntax of n.
 | 
						|
func (o *oracle) printNode(n ast.Node) string {
 | 
						|
	var buf bytes.Buffer
 | 
						|
	printer.Fprint(&buf, o.prog.Fset, n)
 | 
						|
	return buf.String()
 | 
						|
}
 |