diff --git a/cmd/oracle/oracle.el b/cmd/oracle/oracle.el index 35cd8922..f90a5fd0 100644 --- a/cmd/oracle/oracle.el +++ b/cmd/oracle/oracle.el @@ -213,6 +213,12 @@ identifier." (interactive) (go-oracle--run "referrers")) +(defun go-oracle-whicherrs () + "Show globals, constants and types to which the selected +expression (of type 'error') may refer." + (interactive) + (go-oracle--run "whicherrs")) + ;; TODO(dominikh): better docstring (define-minor-mode go-oracle-mode "Oracle minor mode for go-mode diff --git a/oracle/oracle.go b/oracle/oracle.go index 55566c3f..f9c13e82 100644 --- a/oracle/oracle.go +++ b/oracle/oracle.go @@ -101,6 +101,7 @@ var modes = []*modeInfo{ {"callstack", needPTA | needPos, callstack}, {"peers", needPTA | needSSADebug | needPos, peers}, {"pointsto", needPTA | needSSADebug | needExactPos, pointsto}, + {"whicherrs", needPTA | needSSADebug | needExactPos, whicherrs}, // Type-based, modular analyses: {"definition", needPos, definition}, diff --git a/oracle/oracle_test.go b/oracle/oracle_test.go index 9c22a7d7..9775d6fc 100644 --- a/oracle/oracle_test.go +++ b/oracle/oracle_test.go @@ -213,6 +213,7 @@ func TestOracle(t *testing.T) { "testdata/src/main/pointsto.go", "testdata/src/main/reflection.go", "testdata/src/main/what.go", + "testdata/src/main/whicherrs.go", // JSON: // TODO(adonovan): most of these are very similar; combine them. "testdata/src/main/callgraph-json.go", diff --git a/oracle/serial/serial.go b/oracle/serial/serial.go index 67768722..6432b611 100644 --- a/oracle/serial/serial.go +++ b/oracle/serial/serial.go @@ -229,6 +229,21 @@ type PTAWarning struct { Message string `json:"message"` // warning message } +// A WhichErrs is the result of a 'whicherrs' query. +// It contains the position of the queried error and the possible globals, +// constants, and types it may point to. +type WhichErrs struct { + ErrPos string `json:"errpos,omitempty"` // location of queried error + Globals []string `json:"globals,omitempty"` // locations of globals + Constants []string `json:"constants,omitempty"` // locations of constants + Types []WhichErrsType `json:"types,omitempty"` // Types +} + +type WhichErrsType struct { + Type string `json:"type,omitempty"` + Position string `json:"position,omitempty"` +} + // A Result is the common result of any oracle query. // It contains a query-specific result element. // @@ -251,6 +266,7 @@ type Result struct { PointsTo []PointsTo `json:"pointsto,omitempty"` Referrers *Referrers `json:"referrers,omitempty"` What *What `json:"what,omitempty"` + WhichErrs *WhichErrs `json:"whicherrs,omitempty"` Warnings []PTAWarning `json:"warnings,omitempty"` // warnings from pointer analysis } diff --git a/oracle/testdata/src/main/whicherrs.go b/oracle/testdata/src/main/whicherrs.go new file mode 100644 index 00000000..27fe6b56 --- /dev/null +++ b/oracle/testdata/src/main/whicherrs.go @@ -0,0 +1,27 @@ +package main + +type errType string + +const constErr errType = "blah" + +func (et errType) Error() string { + return string(et) +} + +var errVar error = errType("foo") + +func genErr(i int) error { + switch i { + case 0: + return constErr + case 1: + return errVar + default: + return nil + } +} + +func main() { + err := genErr(0) // @whicherrs localerrs "err" + _ = err +} diff --git a/oracle/testdata/src/main/whicherrs.golden b/oracle/testdata/src/main/whicherrs.golden new file mode 100644 index 00000000..1118e0a8 --- /dev/null +++ b/oracle/testdata/src/main/whicherrs.golden @@ -0,0 +1,8 @@ +-------- @whicherrs localerrs -------- +this error may point to these globals: + errVar +this error may contain these constants: + constErr +this error may contain these dynamic types: + errType + diff --git a/oracle/whicherrs.go b/oracle/whicherrs.go new file mode 100644 index 00000000..9ffe2ac3 --- /dev/null +++ b/oracle/whicherrs.go @@ -0,0 +1,294 @@ +// Copyright 2014 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package oracle + +import ( + "fmt" + "go/ast" + "go/token" + "sort" + + "golang.org/x/tools/astutil" + "golang.org/x/tools/go/ssa" + "golang.org/x/tools/go/ssa/ssautil" + "golang.org/x/tools/go/types" + "golang.org/x/tools/oracle/serial" +) + +var builtinErrorType = types.Universe.Lookup("error").Type() + +// whicherrs takes an position to an error and tries to find all types, constants +// and global value which a given error can point to and which can be checked from the +// scope where the error lives. +// In short, it returns a list of things that can be checked against in order to handle +// an error properly. +// +// TODO(dmorsing): figure out if fields in errors like *os.PathError.Err +// can be queried recursively somehow. +func whicherrs(o *Oracle, qpos *QueryPos) (queryResult, error) { + path, action := findInterestingNode(qpos.info, qpos.path) + if action != actionExpr { + return nil, fmt.Errorf("whicherrs wants an expression; got %s", + astutil.NodeDescription(qpos.path[0])) + } + var expr ast.Expr + var obj types.Object + switch n := path[0].(type) { + case *ast.ValueSpec: + // ambiguous ValueSpec containing multiple names + return nil, fmt.Errorf("multiple value specification") + case *ast.Ident: + obj = qpos.info.ObjectOf(n) + expr = n + case ast.Expr: + expr = n + default: + return nil, fmt.Errorf("unexpected AST for expr: %T", n) + } + + typ := qpos.info.TypeOf(expr) + if !types.Identical(typ, builtinErrorType) { + return nil, fmt.Errorf("selection is not an expression of type 'error'") + } + // Determine the ssa.Value for the expression. + var value ssa.Value + var err error + if obj != nil { + // def/ref of func/var object + value, _, err = ssaValueForIdent(o.prog, qpos.info, obj, path) + } else { + value, _, err = ssaValueForExpr(o.prog, qpos.info, path) + } + if err != nil { + return nil, err // e.g. trivially dead code + } + buildSSA(o) + + globals := findVisibleErrs(o.prog, qpos) + constants := findVisibleConsts(o.prog, qpos) + + res := &whicherrsResult{ + qpos: qpos, + errpos: expr.Pos(), + } + + // Find the instruction which initialized the + // global error. If more than one instruction has stored to the global + // remove the global from the set of values that we want to query. + allFuncs := ssautil.AllFunctions(o.prog) + for fn := range allFuncs { + for _, b := range fn.Blocks { + for _, instr := range b.Instrs { + store, ok := instr.(*ssa.Store) + if !ok { + continue + } + gval, ok := store.Addr.(*ssa.Global) + if !ok { + continue + } + gbl, ok := globals[gval] + if !ok { + continue + } + // we already found a store to this global + // The normal error define is just one store in the init + // so we just remove this global from the set we want to query + if gbl != nil { + delete(globals, gval) + } + globals[gval] = store.Val + } + } + } + + o.ptaConfig.AddQuery(value) + for _, v := range globals { + o.ptaConfig.AddQuery(v) + } + + ptares := ptrAnalysis(o) + valueptr := ptares.Queries[value] + for g, v := range globals { + ptr, ok := ptares.Queries[v] + if !ok { + continue + } + if !ptr.MayAlias(valueptr) { + continue + } + res.globals = append(res.globals, g) + } + pts := valueptr.PointsTo() + dedup := make(map[*ssa.NamedConst]bool) + for _, label := range pts.Labels() { + // These values are either MakeInterfaces or reflect + // generated interfaces. For the purposes of this + // analysis, we don't care about reflect generated ones + makeiface, ok := label.Value().(*ssa.MakeInterface) + if !ok { + continue + } + constval, ok := makeiface.X.(*ssa.Const) + if !ok { + continue + } + c := constants[*constval] + if c != nil && !dedup[c] { + dedup[c] = true + res.consts = append(res.consts, c) + } + } + concs := pts.DynamicTypes() + concs.Iterate(func(conc types.Type, _ interface{}) { + // go/types is a bit annoying here. + // We want to find all the types that we can + // typeswitch or assert to. This means finding out + // if the type pointed to can be seen by us. + // + // For the purposes of this analysis, the type is always + // either a Named type or a pointer to one. + // There are cases where error can be implemented + // by unnamed types, but in that case, we can't assert to + // it, so we don't care about it for this analysis. + var name *types.TypeName + switch t := conc.(type) { + case *types.Pointer: + named, ok := t.Elem().(*types.Named) + if !ok { + return + } + name = named.Obj() + case *types.Named: + name = t.Obj() + default: + return + } + if !isAccessibleFrom(name, qpos.info.Pkg) { + return + } + res.types = append(res.types, &errorType{conc, name}) + }) + sort.Sort(membersByPosAndString(res.globals)) + sort.Sort(membersByPosAndString(res.consts)) + sort.Sort(sorterrorType(res.types)) + return res, nil +} + +// findVisibleErrs returns a mapping from each package-level variable of type "error" to nil. +func findVisibleErrs(prog *ssa.Program, qpos *QueryPos) map[*ssa.Global]ssa.Value { + globals := make(map[*ssa.Global]ssa.Value) + for _, pkg := range prog.AllPackages() { + for _, mem := range pkg.Members { + gbl, ok := mem.(*ssa.Global) + if !ok { + continue + } + gbltype := gbl.Type() + // globals are always pointers + if !types.Identical(deref(gbltype), builtinErrorType) { + continue + } + if !isAccessibleFrom(gbl.Object(), qpos.info.Pkg) { + continue + } + globals[gbl] = nil + } + } + return globals +} + +// findVisibleConsts returns a mapping from each package-level constant assignable to type "error", to nil. +func findVisibleConsts(prog *ssa.Program, qpos *QueryPos) map[ssa.Const]*ssa.NamedConst { + constants := make(map[ssa.Const]*ssa.NamedConst) + for _, pkg := range prog.AllPackages() { + for _, mem := range pkg.Members { + obj, ok := mem.(*ssa.NamedConst) + if !ok { + continue + } + consttype := obj.Type() + if !types.AssignableTo(consttype, builtinErrorType) { + continue + } + if !isAccessibleFrom(obj.Object(), qpos.info.Pkg) { + continue + } + constants[*obj.Value] = obj + } + } + + return constants +} + +type membersByPosAndString []ssa.Member + +func (a membersByPosAndString) Len() int { return len(a) } +func (a membersByPosAndString) Less(i, j int) bool { + cmp := a[i].Pos() - a[j].Pos() + return cmp < 0 || cmp == 0 && a[i].String() < a[j].String() +} +func (a membersByPosAndString) Swap(i, j int) { a[i], a[j] = a[j], a[i] } + +type sorterrorType []*errorType + +func (a sorterrorType) Len() int { return len(a) } +func (a sorterrorType) Less(i, j int) bool { + cmp := a[i].obj.Pos() - a[j].obj.Pos() + return cmp < 0 || cmp == 0 && a[i].typ.String() < a[j].typ.String() +} +func (a sorterrorType) Swap(i, j int) { a[i], a[j] = a[j], a[i] } + +type errorType struct { + typ types.Type // concrete type N or *N that implements error + obj *types.TypeName // the named type N +} + +type whicherrsResult struct { + qpos *QueryPos + errpos token.Pos + globals []ssa.Member + consts []ssa.Member + types []*errorType +} + +func (r *whicherrsResult) display(printf printfFunc) { + if len(r.globals) > 0 { + printf(r.qpos, "this error may point to these globals:") + for _, g := range r.globals { + printf(g.Pos(), "\t%s", g.RelString(r.qpos.info.Pkg)) + } + } + if len(r.consts) > 0 { + printf(r.qpos, "this error may contain these constants:") + for _, c := range r.consts { + printf(c.Pos(), "\t%s", c.RelString(r.qpos.info.Pkg)) + } + } + if len(r.types) > 0 { + printf(r.qpos, "this error may contain these dynamic types:") + for _, t := range r.types { + printf(t.obj.Pos(), "\t%s", r.qpos.TypeString(t.typ)) + } + } +} + +func (r *whicherrsResult) toSerial(res *serial.Result, fset *token.FileSet) { + we := &serial.WhichErrs{} + we.ErrPos = fset.Position(r.errpos).String() + for _, g := range r.globals { + we.Globals = append(we.Globals, fset.Position(g.Pos()).String()) + } + for _, c := range r.consts { + we.Constants = append(we.Constants, fset.Position(c.Pos()).String()) + } + for _, t := range r.types { + var et serial.WhichErrsType + et.Type = r.qpos.TypeString(t.typ) + et.Position = fset.Position(t.obj.Pos()).String() + we.Types = append(we.Types, et) + } + res.WhichErrs = we +}