295 lines
		
	
	
		
			8.0 KiB
		
	
	
	
		
			Go
		
	
	
	
			
		
		
	
	
			295 lines
		
	
	
		
			8.0 KiB
		
	
	
	
		
			Go
		
	
	
	
| // Copyright 2014 The Go Authors. All rights reserved.
 | |
| // Use of this source code is governed by a BSD-style
 | |
| // license that can be found in the LICENSE file.
 | |
| 
 | |
| package oracle
 | |
| 
 | |
| import (
 | |
| 	"fmt"
 | |
| 	"go/ast"
 | |
| 	"go/token"
 | |
| 	"sort"
 | |
| 
 | |
| 	"golang.org/x/tools/go/ast/astutil"
 | |
| 	"golang.org/x/tools/go/ssa"
 | |
| 	"golang.org/x/tools/go/ssa/ssautil"
 | |
| 	"golang.org/x/tools/go/types"
 | |
| 	"golang.org/x/tools/oracle/serial"
 | |
| )
 | |
| 
 | |
| var builtinErrorType = types.Universe.Lookup("error").Type()
 | |
| 
 | |
| // whicherrs takes an position to an error and tries to find all types, constants
 | |
| // and global value which a given error can point to and which can be checked from the
 | |
| // scope where the error lives.
 | |
| // In short, it returns a list of things that can be checked against in order to handle
 | |
| // an error properly.
 | |
| //
 | |
| // TODO(dmorsing): figure out if fields in errors like *os.PathError.Err
 | |
| // can be queried recursively somehow.
 | |
| func whicherrs(o *Oracle, qpos *QueryPos) (queryResult, error) {
 | |
| 	path, action := findInterestingNode(qpos.info, qpos.path)
 | |
| 	if action != actionExpr {
 | |
| 		return nil, fmt.Errorf("whicherrs wants an expression; got %s",
 | |
| 			astutil.NodeDescription(qpos.path[0]))
 | |
| 	}
 | |
| 	var expr ast.Expr
 | |
| 	var obj types.Object
 | |
| 	switch n := path[0].(type) {
 | |
| 	case *ast.ValueSpec:
 | |
| 		// ambiguous ValueSpec containing multiple names
 | |
| 		return nil, fmt.Errorf("multiple value specification")
 | |
| 	case *ast.Ident:
 | |
| 		obj = qpos.info.ObjectOf(n)
 | |
| 		expr = n
 | |
| 	case ast.Expr:
 | |
| 		expr = n
 | |
| 	default:
 | |
| 		return nil, fmt.Errorf("unexpected AST for expr: %T", n)
 | |
| 	}
 | |
| 
 | |
| 	typ := qpos.info.TypeOf(expr)
 | |
| 	if !types.Identical(typ, builtinErrorType) {
 | |
| 		return nil, fmt.Errorf("selection is not an expression of type 'error'")
 | |
| 	}
 | |
| 	// Determine the ssa.Value for the expression.
 | |
| 	var value ssa.Value
 | |
| 	var err error
 | |
| 	if obj != nil {
 | |
| 		// def/ref of func/var object
 | |
| 		value, _, err = ssaValueForIdent(o.prog, qpos.info, obj, path)
 | |
| 	} else {
 | |
| 		value, _, err = ssaValueForExpr(o.prog, qpos.info, path)
 | |
| 	}
 | |
| 	if err != nil {
 | |
| 		return nil, err // e.g. trivially dead code
 | |
| 	}
 | |
| 	buildSSA(o)
 | |
| 
 | |
| 	globals := findVisibleErrs(o.prog, qpos)
 | |
| 	constants := findVisibleConsts(o.prog, qpos)
 | |
| 
 | |
| 	res := &whicherrsResult{
 | |
| 		qpos:   qpos,
 | |
| 		errpos: expr.Pos(),
 | |
| 	}
 | |
| 
 | |
| 	// Find the instruction which initialized the
 | |
| 	// global error. If more than one instruction has stored to the global
 | |
| 	// remove the global from the set of values that we want to query.
 | |
| 	allFuncs := ssautil.AllFunctions(o.prog)
 | |
| 	for fn := range allFuncs {
 | |
| 		for _, b := range fn.Blocks {
 | |
| 			for _, instr := range b.Instrs {
 | |
| 				store, ok := instr.(*ssa.Store)
 | |
| 				if !ok {
 | |
| 					continue
 | |
| 				}
 | |
| 				gval, ok := store.Addr.(*ssa.Global)
 | |
| 				if !ok {
 | |
| 					continue
 | |
| 				}
 | |
| 				gbl, ok := globals[gval]
 | |
| 				if !ok {
 | |
| 					continue
 | |
| 				}
 | |
| 				// we already found a store to this global
 | |
| 				// The normal error define is just one store in the init
 | |
| 				// so we just remove this global from the set we want to query
 | |
| 				if gbl != nil {
 | |
| 					delete(globals, gval)
 | |
| 				}
 | |
| 				globals[gval] = store.Val
 | |
| 			}
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	o.ptaConfig.AddQuery(value)
 | |
| 	for _, v := range globals {
 | |
| 		o.ptaConfig.AddQuery(v)
 | |
| 	}
 | |
| 
 | |
| 	ptares := ptrAnalysis(o)
 | |
| 	valueptr := ptares.Queries[value]
 | |
| 	for g, v := range globals {
 | |
| 		ptr, ok := ptares.Queries[v]
 | |
| 		if !ok {
 | |
| 			continue
 | |
| 		}
 | |
| 		if !ptr.MayAlias(valueptr) {
 | |
| 			continue
 | |
| 		}
 | |
| 		res.globals = append(res.globals, g)
 | |
| 	}
 | |
| 	pts := valueptr.PointsTo()
 | |
| 	dedup := make(map[*ssa.NamedConst]bool)
 | |
| 	for _, label := range pts.Labels() {
 | |
| 		// These values are either MakeInterfaces or reflect
 | |
| 		// generated interfaces. For the purposes of this
 | |
| 		// analysis, we don't care about reflect generated ones
 | |
| 		makeiface, ok := label.Value().(*ssa.MakeInterface)
 | |
| 		if !ok {
 | |
| 			continue
 | |
| 		}
 | |
| 		constval, ok := makeiface.X.(*ssa.Const)
 | |
| 		if !ok {
 | |
| 			continue
 | |
| 		}
 | |
| 		c := constants[*constval]
 | |
| 		if c != nil && !dedup[c] {
 | |
| 			dedup[c] = true
 | |
| 			res.consts = append(res.consts, c)
 | |
| 		}
 | |
| 	}
 | |
| 	concs := pts.DynamicTypes()
 | |
| 	concs.Iterate(func(conc types.Type, _ interface{}) {
 | |
| 		// go/types is a bit annoying here.
 | |
| 		// We want to find all the types that we can
 | |
| 		// typeswitch or assert to. This means finding out
 | |
| 		// if the type pointed to can be seen by us.
 | |
| 		//
 | |
| 		// For the purposes of this analysis, the type is always
 | |
| 		// either a Named type or a pointer to one.
 | |
| 		// There are cases where error can be implemented
 | |
| 		// by unnamed types, but in that case, we can't assert to
 | |
| 		// it, so we don't care about it for this analysis.
 | |
| 		var name *types.TypeName
 | |
| 		switch t := conc.(type) {
 | |
| 		case *types.Pointer:
 | |
| 			named, ok := t.Elem().(*types.Named)
 | |
| 			if !ok {
 | |
| 				return
 | |
| 			}
 | |
| 			name = named.Obj()
 | |
| 		case *types.Named:
 | |
| 			name = t.Obj()
 | |
| 		default:
 | |
| 			return
 | |
| 		}
 | |
| 		if !isAccessibleFrom(name, qpos.info.Pkg) {
 | |
| 			return
 | |
| 		}
 | |
| 		res.types = append(res.types, &errorType{conc, name})
 | |
| 	})
 | |
| 	sort.Sort(membersByPosAndString(res.globals))
 | |
| 	sort.Sort(membersByPosAndString(res.consts))
 | |
| 	sort.Sort(sorterrorType(res.types))
 | |
| 	return res, nil
 | |
| }
 | |
| 
 | |
| // findVisibleErrs returns a mapping from each package-level variable of type "error" to nil.
 | |
| func findVisibleErrs(prog *ssa.Program, qpos *QueryPos) map[*ssa.Global]ssa.Value {
 | |
| 	globals := make(map[*ssa.Global]ssa.Value)
 | |
| 	for _, pkg := range prog.AllPackages() {
 | |
| 		for _, mem := range pkg.Members {
 | |
| 			gbl, ok := mem.(*ssa.Global)
 | |
| 			if !ok {
 | |
| 				continue
 | |
| 			}
 | |
| 			gbltype := gbl.Type()
 | |
| 			// globals are always pointers
 | |
| 			if !types.Identical(deref(gbltype), builtinErrorType) {
 | |
| 				continue
 | |
| 			}
 | |
| 			if !isAccessibleFrom(gbl.Object(), qpos.info.Pkg) {
 | |
| 				continue
 | |
| 			}
 | |
| 			globals[gbl] = nil
 | |
| 		}
 | |
| 	}
 | |
| 	return globals
 | |
| }
 | |
| 
 | |
| // findVisibleConsts returns a mapping from each package-level constant assignable to type "error", to nil.
 | |
| func findVisibleConsts(prog *ssa.Program, qpos *QueryPos) map[ssa.Const]*ssa.NamedConst {
 | |
| 	constants := make(map[ssa.Const]*ssa.NamedConst)
 | |
| 	for _, pkg := range prog.AllPackages() {
 | |
| 		for _, mem := range pkg.Members {
 | |
| 			obj, ok := mem.(*ssa.NamedConst)
 | |
| 			if !ok {
 | |
| 				continue
 | |
| 			}
 | |
| 			consttype := obj.Type()
 | |
| 			if !types.AssignableTo(consttype, builtinErrorType) {
 | |
| 				continue
 | |
| 			}
 | |
| 			if !isAccessibleFrom(obj.Object(), qpos.info.Pkg) {
 | |
| 				continue
 | |
| 			}
 | |
| 			constants[*obj.Value] = obj
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	return constants
 | |
| }
 | |
| 
 | |
| type membersByPosAndString []ssa.Member
 | |
| 
 | |
| func (a membersByPosAndString) Len() int { return len(a) }
 | |
| func (a membersByPosAndString) Less(i, j int) bool {
 | |
| 	cmp := a[i].Pos() - a[j].Pos()
 | |
| 	return cmp < 0 || cmp == 0 && a[i].String() < a[j].String()
 | |
| }
 | |
| func (a membersByPosAndString) Swap(i, j int) { a[i], a[j] = a[j], a[i] }
 | |
| 
 | |
| type sorterrorType []*errorType
 | |
| 
 | |
| func (a sorterrorType) Len() int { return len(a) }
 | |
| func (a sorterrorType) Less(i, j int) bool {
 | |
| 	cmp := a[i].obj.Pos() - a[j].obj.Pos()
 | |
| 	return cmp < 0 || cmp == 0 && a[i].typ.String() < a[j].typ.String()
 | |
| }
 | |
| func (a sorterrorType) Swap(i, j int) { a[i], a[j] = a[j], a[i] }
 | |
| 
 | |
// errorType describes one dynamic type that a queried error may hold and that
// the query package can name (e.g. in a type switch or type assertion).
// NOTE: field order matters; whicherrs builds it with an unkeyed literal.
type errorType struct {
	typ types.Type      // concrete type N or *N that implements error
	obj *types.TypeName // the named type N
}
 | |
| 
 | |
| type whicherrsResult struct {
 | |
| 	qpos    *QueryPos
 | |
| 	errpos  token.Pos
 | |
| 	globals []ssa.Member
 | |
| 	consts  []ssa.Member
 | |
| 	types   []*errorType
 | |
| }
 | |
| 
 | |
| func (r *whicherrsResult) display(printf printfFunc) {
 | |
| 	if len(r.globals) > 0 {
 | |
| 		printf(r.qpos, "this error may point to these globals:")
 | |
| 		for _, g := range r.globals {
 | |
| 			printf(g.Pos(), "\t%s", g.RelString(r.qpos.info.Pkg))
 | |
| 		}
 | |
| 	}
 | |
| 	if len(r.consts) > 0 {
 | |
| 		printf(r.qpos, "this error may contain these constants:")
 | |
| 		for _, c := range r.consts {
 | |
| 			printf(c.Pos(), "\t%s", c.RelString(r.qpos.info.Pkg))
 | |
| 		}
 | |
| 	}
 | |
| 	if len(r.types) > 0 {
 | |
| 		printf(r.qpos, "this error may contain these dynamic types:")
 | |
| 		for _, t := range r.types {
 | |
| 			printf(t.obj.Pos(), "\t%s", r.qpos.TypeString(t.typ))
 | |
| 		}
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func (r *whicherrsResult) toSerial(res *serial.Result, fset *token.FileSet) {
 | |
| 	we := &serial.WhichErrs{}
 | |
| 	we.ErrPos = fset.Position(r.errpos).String()
 | |
| 	for _, g := range r.globals {
 | |
| 		we.Globals = append(we.Globals, fset.Position(g.Pos()).String())
 | |
| 	}
 | |
| 	for _, c := range r.consts {
 | |
| 		we.Constants = append(we.Constants, fset.Position(c.Pos()).String())
 | |
| 	}
 | |
| 	for _, t := range r.types {
 | |
| 		var et serial.WhichErrsType
 | |
| 		et.Type = r.qpos.TypeString(t.typ)
 | |
| 		et.Position = fset.Position(t.obj.Pos()).String()
 | |
| 		we.Types = append(we.Types, et)
 | |
| 	}
 | |
| 	res.WhichErrs = we
 | |
| }
 |