From 25a0cc4bfd406533eedfbd18180ffe23a10d8c3a Mon Sep 17 00:00:00 2001 From: Alan Donovan Date: Mon, 23 Sep 2013 15:02:18 -0400 Subject: [PATCH] go.tools/oracle: refactor Oracle API to allow repeated queries on same scope. The existing standalone Query function builds an importer, ssa.Program, oracle, and query position, executes the query and returns the result. For clients (such as Frederik Zipp's web-based github.com/fzipp/pythia tool) that wish to load the program once and make several queries, we now expose these as separate operations too. Here's a client, in pseudocode: o := oracle.New(...) for ... { qpos := o.ParseQueryPos(...) res := o.Query(mode, qpos) print result } NB: this is a slight deoptimisation in the one-shot case since we have to build the entire SSA program with debug info, not just the query package, since we now don't know the query package at that time. The 'exact' param to ParseQueryPos needs more thought since its ideal value is a function of the query mode. This will do for now. Details: - expose Oracle type, New() func and Query() method. - expose QueryPos type and ParseQueryPos func. - improved package doc comment. - un-exposed the "needs" bits. - added test. R=crawshaw CC=frederik.zipp, golang-dev https://golang.org/cl/13810043 --- oracle/callees.go | 10 +- oracle/callers.go | 14 +- oracle/callgraph.go | 5 +- oracle/callstack.go | 18 +- oracle/describe.go | 74 ++++--- oracle/freevars.go | 22 +- oracle/implements.go | 4 +- oracle/oracle.go | 267 +++++++++++++++-------- oracle/oracle_test.go | 51 +++++ oracle/peers.go | 10 +- oracle/referrers.go | 10 +- oracle/testdata/src/main/calls.golden | 1 - oracle/testdata/src/main/freevars.golden | 6 +- oracle/testdata/src/main/multi.go | 13 ++ 14 files changed, 332 insertions(+), 173 deletions(-) create mode 100644 oracle/testdata/src/main/multi.go diff --git a/oracle/callees.go b/oracle/callees.go index afd8d83f..1e0d65e1 100644 --- a/oracle/callees.go +++ b/oracle/callees.go @@ -20,28 +20,28 @@ import ( // // TODO(adonovan): if a callee is a wrapper, show the callee's callee. // -func callees(o *oracle) (queryResult, error) { +func callees(o *Oracle, qpos *QueryPos) (queryResult, error) { // Determine the enclosing call for the specified position. var call *ast.CallExpr - for _, n := range o.queryPath { + for _, n := range qpos.path { if call, _ = n.(*ast.CallExpr); call != nil { break } } if call == nil { - return nil, o.errorf(o.queryPath[0], "there is no function call here") + return nil, o.errorf(qpos.path[0], "there is no function call here") } // TODO(adonovan): issue an error if the call is "too far // away" from the current selection, as this most likely is // not what the user intended. // Reject type conversions. - if o.queryPkgInfo.IsType(call.Fun) { + if qpos.info.IsType(call.Fun) { return nil, o.errorf(call, "this is a type conversion, not a function call") } // Reject calls to built-ins. - if b, ok := o.queryPkgInfo.TypeOf(call.Fun).(*types.Builtin); ok { + if b, ok := qpos.info.TypeOf(call.Fun).(*types.Builtin); ok { return nil, o.errorf(call, "this is a call to the built-in '%s' operator", b.Name()) } diff --git a/oracle/callers.go b/oracle/callers.go index 37549978..01f827be 100644 --- a/oracle/callers.go +++ b/oracle/callers.go @@ -17,20 +17,20 @@ import ( // // TODO(adonovan): if a caller is a wrapper, show the caller's caller. // -func callers(o *oracle) (queryResult, error) { - pkg := o.prog.Package(o.queryPkgInfo.Pkg) +func callers(o *Oracle, qpos *QueryPos) (queryResult, error) { + pkg := o.prog.Package(qpos.info.Pkg) if pkg == nil { - return nil, o.errorf(o.queryPath[0], "no SSA package") + return nil, o.errorf(qpos.path[0], "no SSA package") } - if !ssa.HasEnclosingFunction(pkg, o.queryPath) { - return nil, o.errorf(o.queryPath[0], "this position is not inside a function") + if !ssa.HasEnclosingFunction(pkg, qpos.path) { + return nil, o.errorf(qpos.path[0], "this position is not inside a function") } buildSSA(o) - target := ssa.EnclosingFunction(pkg, o.queryPath) + target := ssa.EnclosingFunction(pkg, qpos.path) if target == nil { - return nil, o.errorf(o.queryPath[0], "no SSA function built for this location (dead code?)") + return nil, o.errorf(qpos.path[0], "no SSA function built for this location (dead code?)") } // Run the pointer analysis, recording each diff --git a/oracle/callgraph.go b/oracle/callgraph.go index 6ebe1912..c8cdd680 100644 --- a/oracle/callgraph.go +++ b/oracle/callgraph.go @@ -25,7 +25,7 @@ import ( // // TODO(adonovan): elide nodes for synthetic functions? // -func callgraph(o *oracle) (queryResult, error) { +func callgraph(o *Oracle, _ *QueryPos) (queryResult, error) { buildSSA(o) // Run the pointer analysis and build the complete callgraph. @@ -35,6 +35,7 @@ func callgraph(o *oracle) (queryResult, error) { // Assign (preorder) numbers to all the callgraph nodes. // TODO(adonovan): the callgraph API should do this for us. + // (Actually, it does have unique numbers under the hood.) numbering := make(map[pointer.CallGraphNode]int) var number func(cgn pointer.CallGraphNode) number = func(cgn pointer.CallGraphNode) { @@ -70,6 +71,8 @@ Some nodes may appear multiple times due to context-sensitive treatment of some calls. `) + // TODO(adonovan): compute the numbers as we print; right now + // it depends on map iteration so it's arbitrary,which is ugly. seen := make(map[pointer.CallGraphNode]bool) var print func(cgn pointer.CallGraphNode, indent int) print = func(cgn pointer.CallGraphNode, indent int) { diff --git a/oracle/callstack.go b/oracle/callstack.go index 16516b6d..6a5f4323 100644 --- a/oracle/callstack.go +++ b/oracle/callstack.go @@ -22,21 +22,21 @@ import ( // TODO(adonovan): permit user to specify a starting point other than // the analysis root. // -func callstack(o *oracle) (queryResult, error) { - pkg := o.prog.Package(o.queryPkgInfo.Pkg) +func callstack(o *Oracle, qpos *QueryPos) (queryResult, error) { + pkg := o.prog.Package(qpos.info.Pkg) if pkg == nil { - return nil, o.errorf(o.queryPath[0], "no SSA package") + return nil, o.errorf(qpos.path[0], "no SSA package") } - if !ssa.HasEnclosingFunction(pkg, o.queryPath) { - return nil, o.errorf(o.queryPath[0], "this position is not inside a function") + if !ssa.HasEnclosingFunction(pkg, qpos.path) { + return nil, o.errorf(qpos.path[0], "this position is not inside a function") } buildSSA(o) - target := ssa.EnclosingFunction(pkg, o.queryPath) + target := ssa.EnclosingFunction(pkg, qpos.path) if target == nil { - return nil, o.errorf(o.queryPath[0], + return nil, o.errorf(qpos.path[0], "no SSA function built for this location (dead code?)") } @@ -74,19 +74,21 @@ func callstack(o *oracle) (queryResult, error) { } return &callstackResult{ + qpos: qpos, target: target, callstack: callstack, }, nil } type callstackResult struct { + qpos *QueryPos target *ssa.Function callstack []pointer.CallSite } func (r *callstackResult) display(printf printfFunc) { if r.callstack != nil { - printf(false, "Found a call path from root to %s", r.target) + printf(r.qpos, "Found a call path from root to %s", r.target) printf(r.target, "%s", r.target) for _, site := range r.callstack { printf(site, "%s from %s", site.Description(), site.Caller().Func()) diff --git a/oracle/describe.go b/oracle/describe.go index 143c7e4c..2c30823c 100644 --- a/oracle/describe.go +++ b/oracle/describe.go @@ -33,25 +33,25 @@ import ( // // All printed sets are sorted to ensure determinism. // -func describe(o *oracle) (queryResult, error) { +func describe(o *Oracle, qpos *QueryPos) (queryResult, error) { if false { // debugging - o.fprintf(os.Stderr, o.queryPath[0], "you selected: %s %s", - importer.NodeDescription(o.queryPath[0]), pathToString2(o.queryPath)) + o.fprintf(os.Stderr, qpos.path[0], "you selected: %s %s", + importer.NodeDescription(qpos.path[0]), pathToString2(qpos.path)) } - path, action := findInterestingNode(o.queryPkgInfo, o.queryPath) + path, action := findInterestingNode(qpos.info, qpos.path) switch action { case actionExpr: - return describeValue(o, path) + return describeValue(o, qpos, path) case actionType: - return describeType(o, path) + return describeType(o, qpos, path) case actionPackage: - return describePackage(o, path) + return describePackage(o, qpos, path) case actionStmt: - return describeStmt(o, path) + return describeStmt(o, qpos, path) case actionUnknown: return &describeUnknownResult{path[0]}, nil @@ -305,11 +305,11 @@ func findInterestingNode(pkginfo *importer.PackageInfo, path []ast.Node) ([]ast. // to the root of the AST is path. It may return a nil Value without // an error to indicate the pointer analysis is not appropriate. // -func ssaValueForIdent(o *oracle, obj types.Object, path []ast.Node) (ssa.Value, error) { +func ssaValueForIdent(prog *ssa.Program, qinfo *importer.PackageInfo, obj types.Object, path []ast.Node) (ssa.Value, error) { if obj, ok := obj.(*types.Var); ok { - pkg := o.prog.Package(o.queryPkgInfo.Pkg) + pkg := prog.Package(qinfo.Pkg) pkg.Build() - if v := o.prog.VarValue(obj, pkg, path); v != nil { + if v := prog.VarValue(obj, pkg, path); v != nil { // Don't run pointer analysis on a ref to a const expression. if _, ok := v.(*ssa.Const); ok { v = nil @@ -328,8 +328,8 @@ func ssaValueForIdent(o *oracle, obj types.Object, path []ast.Node) (ssa.Value, // return a nil Value without an error to indicate the pointer // analysis is not appropriate. // -func ssaValueForExpr(o *oracle, path []ast.Node) (ssa.Value, error) { - pkg := o.prog.Package(o.queryPkgInfo.Pkg) +func ssaValueForExpr(prog *ssa.Program, qinfo *importer.PackageInfo, path []ast.Node) (ssa.Value, error) { + pkg := prog.Package(qinfo.Pkg) pkg.SetDebugMode(true) pkg.Build() @@ -345,7 +345,7 @@ func ssaValueForExpr(o *oracle, path []ast.Node) (ssa.Value, error) { return nil, fmt.Errorf("can't locate SSA Value for expression in %s", fn) } -func describeValue(o *oracle, path []ast.Node) (*describeValueResult, error) { +func describeValue(o *Oracle, qpos *QueryPos, path []ast.Node) (*describeValueResult, error) { var expr ast.Expr var obj types.Object switch n := path[0].(type) { @@ -353,7 +353,7 @@ func describeValue(o *oracle, path []ast.Node) (*describeValueResult, error) { // ambiguous ValueSpec containing multiple names return nil, o.errorf(n, "multiple value specification") case *ast.Ident: - obj = o.queryPkgInfo.ObjectOf(n) + obj = qpos.info.ObjectOf(n) expr = n case ast.Expr: expr = n @@ -362,8 +362,8 @@ func describeValue(o *oracle, path []ast.Node) (*describeValueResult, error) { return nil, o.errorf(n, "unexpected AST for expr: %T", n) } - typ := o.queryPkgInfo.TypeOf(expr) - constVal := o.queryPkgInfo.ValueOf(expr) + typ := qpos.info.TypeOf(expr) + constVal := qpos.info.ValueOf(expr) // From this point on, we cannot fail with an error. // Failure to run the pointer analysis will be reported later. @@ -386,11 +386,11 @@ func describeValue(o *oracle, path []ast.Node) (*describeValueResult, error) { var value ssa.Value if obj != nil { // def/ref of func/var/const object - value, ptaErr = ssaValueForIdent(o, obj, path) + value, ptaErr = ssaValueForIdent(o.prog, qpos.info, obj, path) } else { // any other expression - if o.queryPkgInfo.ValueOf(path[0].(ast.Expr)) == nil { // non-constant? - value, ptaErr = ssaValueForExpr(o, path) + if qpos.info.ValueOf(path[0].(ast.Expr)) == nil { // non-constant? + value, ptaErr = ssaValueForExpr(o.prog, qpos.info, path) } } if value != nil { @@ -404,6 +404,7 @@ func describeValue(o *oracle, path []ast.Node) (*describeValueResult, error) { } return &describeValueResult{ + qpos: qpos, expr: expr, typ: typ, constVal: constVal, @@ -414,7 +415,7 @@ func describeValue(o *oracle, path []ast.Node) (*describeValueResult, error) { } // describePointer runs the pointer analysis of the selected SSA value. -func describePointer(o *oracle, v ssa.Value, indirect bool) (ptrs []pointerResult, err error) { +func describePointer(o *Oracle, v ssa.Value, indirect bool) (ptrs []pointerResult, err error) { buildSSA(o) // TODO(adonovan): don't run indirect pointer analysis on non-ptr-ptrlike types. @@ -454,6 +455,7 @@ type pointerResult struct { } type describeValueResult struct { + qpos *QueryPos expr ast.Expr // query node typ types.Type // type of expression constVal exact.Value // value of expression, if constant @@ -524,7 +526,7 @@ func (r *describeValueResult) display(printf printfFunc) { // reflect.Value expression. if len(r.ptrs) > 0 { - printf(false, "this %s may contain these dynamic types:", r.typ) + printf(r.qpos, "this %s may contain these dynamic types:", r.typ) for _, ptr := range r.ptrs { var obj types.Object if nt, ok := deref(ptr.typ).(*types.Named); ok { @@ -538,15 +540,15 @@ func (r *describeValueResult) display(printf printfFunc) { } } } else { - printf(false, "this %s cannot contain any dynamic types.", r.typ) + printf(r.qpos, "this %s cannot contain any dynamic types.", r.typ) } } else { // Show labels for other expressions. if ptr := r.ptrs[0]; len(ptr.labels) > 0 { - printf(false, "value may point to these labels:") + printf(r.qpos, "value may point to these labels:") printLabels(printf, ptr.labels, "\t") } else { - printf(false, "value cannot point to anything.") + printf(r.qpos, "value cannot point to anything.") } } } @@ -622,12 +624,12 @@ func printLabels(printf printfFunc, labels []*pointer.Label, prefix string) { // ---- TYPE ------------------------------------------------------------ -func describeType(o *oracle, path []ast.Node) (*describeTypeResult, error) { +func describeType(o *Oracle, qpos *QueryPos, path []ast.Node) (*describeTypeResult, error) { var description string var t types.Type switch n := path[0].(type) { case *ast.Ident: - t = o.queryPkgInfo.TypeOf(n) + t = qpos.info.TypeOf(n) switch t := t.(type) { case *types.Basic: description = "reference to built-in type " + t.String() @@ -642,7 +644,7 @@ func describeType(o *oracle, path []ast.Node) (*describeTypeResult, error) { } case ast.Expr: - t = o.queryPkgInfo.TypeOf(n) + t = qpos.info.TypeOf(n) description = "type " + t.String() default: @@ -654,7 +656,7 @@ func describeType(o *oracle, path []ast.Node) (*describeTypeResult, error) { node: path[0], description: description, typ: t, - methods: accessibleMethods(t, o.queryPkgInfo.Pkg), + methods: accessibleMethods(t, qpos.info.Pkg), }, nil } @@ -708,7 +710,7 @@ func (r *describeTypeResult) toJSON(res *json.Result, fset *token.FileSet) { // ---- PACKAGE ------------------------------------------------------------ -func describePackage(o *oracle, path []ast.Node) (*describePackageResult, error) { +func describePackage(o *Oracle, qpos *QueryPos, path []ast.Node) (*describePackageResult, error) { var description string var pkg *types.Package switch n := path[0].(type) { @@ -724,12 +726,12 @@ func describePackage(o *oracle, path []ast.Node) (*describePackageResult, error) case *ast.Ident: if _, isDef := path[1].(*ast.File); isDef { // e.g. package id - pkg = o.queryPkgInfo.Pkg + pkg = qpos.info.Pkg description = fmt.Sprintf("definition of package %q", pkg.Path()) } else { // e.g. import id // or id.F() - pkg = o.queryPkgInfo.ObjectOf(n).Pkg() + pkg = qpos.info.ObjectOf(n).Pkg() description = fmt.Sprintf("reference to package %q", pkg.Path()) } @@ -744,11 +746,11 @@ func describePackage(o *oracle, path []ast.Node) (*describePackageResult, error) // Enumerate the accessible package members // in lexicographic order. for _, name := range pkg.Scope().Names() { - if pkg == o.queryPkgInfo.Pkg || ast.IsExported(name) { + if pkg == qpos.info.Pkg || ast.IsExported(name) { mem := pkg.Scope().Lookup(name) var methods []*types.Selection if mem, ok := mem.(*types.TypeName); ok { - methods = accessibleMethods(mem.Type(), o.queryPkgInfo.Pkg) + methods = accessibleMethods(mem.Type(), qpos.info.Pkg) } members = append(members, &describeMember{ mem, @@ -878,11 +880,11 @@ func tokenOf(o types.Object) string { // ---- STATEMENT ------------------------------------------------------------ -func describeStmt(o *oracle, path []ast.Node) (*describeStmtResult, error) { +func describeStmt(o *Oracle, qpos *QueryPos, path []ast.Node) (*describeStmtResult, error) { var description string switch n := path[0].(type) { case *ast.Ident: - if o.queryPkgInfo.ObjectOf(n).Pos() == n.Pos() { + if qpos.info.ObjectOf(n).Pos() == n.Pos() { description = "labelled statement" } else { description = "reference to labelled statement" diff --git a/oracle/freevars.go b/oracle/freevars.go index 07973f2a..5e5d7548 100644 --- a/oracle/freevars.go +++ b/oracle/freevars.go @@ -26,9 +26,9 @@ import ( // these might be interesting. Perhaps group the results into three // bands. // -func freevars(o *oracle) (queryResult, error) { - file := o.queryPath[len(o.queryPath)-1] // the enclosing file - fileScope := o.queryPkgInfo.Scopes[file] +func freevars(o *Oracle, qpos *QueryPos) (queryResult, error) { + file := qpos.path[len(qpos.path)-1] // the enclosing file + fileScope := qpos.info.Scopes[file] pkgScope := fileScope.Parent() // The id and sel functions return non-nil if they denote an @@ -49,7 +49,7 @@ func freevars(o *oracle) (queryResult, error) { } id = func(n *ast.Ident) types.Object { - obj := o.queryPkgInfo.ObjectOf(n) + obj := qpos.info.ObjectOf(n) if obj == nil { return nil // TODO(adonovan): fix: this fails for *types.Label. panic(o.errorf(n, "no types.Object for ast.Ident")) @@ -70,7 +70,7 @@ func freevars(o *oracle) (queryResult, error) { if scope == fileScope || scope == pkgScope { return nil // defined at file or package scope } - if o.startPos <= obj.Pos() && obj.Pos() <= o.endPos { + if qpos.start <= obj.Pos() && obj.Pos() <= qpos.end { return nil // defined within selection => not free } return obj @@ -82,7 +82,7 @@ func freevars(o *oracle) (queryResult, error) { refsMap := make(map[string]freevarsRef) // Visit all the identifiers in the selected ASTs. - ast.Inspect(o.queryPath[0], func(n ast.Node) bool { + ast.Inspect(qpos.path[0], func(n ast.Node) bool { if n == nil { return true // popping DFS stack } @@ -90,7 +90,7 @@ func freevars(o *oracle) (queryResult, error) { // Is this node contained within the selection? // (freevars permits inexact selections, // like two stmts in a block.) - if o.startPos <= n.Pos() && n.End() <= o.endPos { + if qpos.start <= n.Pos() && n.End() <= qpos.end { var obj types.Object var prune bool switch n := n.(type) { @@ -119,7 +119,7 @@ func freevars(o *oracle) (queryResult, error) { panic(obj) } - typ := o.queryPkgInfo.TypeOf(n.(ast.Expr)) + typ := qpos.info.TypeOf(n.(ast.Expr)) ref := freevarsRef{kind, o.printNode(n), typ, obj} refsMap[ref.ref] = ref @@ -139,12 +139,14 @@ func freevars(o *oracle) (queryResult, error) { sort.Sort(byRef(refs)) return &freevarsResult{ + qpos: qpos, fset: o.prog.Fset, refs: refs, }, nil } type freevarsResult struct { + qpos *QueryPos fset *token.FileSet refs []freevarsRef } @@ -158,9 +160,9 @@ type freevarsRef struct { func (r *freevarsResult) display(printf printfFunc) { if len(r.refs) == 0 { - printf(false, "No free identifers.") + printf(r.qpos, "No free identifiers.") } else { - printf(false, "Free identifers:") + printf(r.qpos, "Free identifiers:") for _, ref := range r.refs { printf(ref.obj, "%s %s %s", ref.kind, ref.ref, ref.typ) } diff --git a/oracle/implements.go b/oracle/implements.go index b31617a2..d26cd1ec 100644 --- a/oracle/implements.go +++ b/oracle/implements.go @@ -30,8 +30,8 @@ import ( // actually occur, with examples? (NB: this is not a conservative // answer due to ChangeInterface, i.e. subtyping among interfaces.) // -func implements(o *oracle) (queryResult, error) { - pkg := o.queryPkgInfo.Pkg +func implements(o *Oracle, qpos *QueryPos) (queryResult, error) { + pkg := qpos.info.Pkg // Compute set of named interface/concrete types at package level. var interfaces, concretes []*types.Named diff --git a/oracle/oracle.go b/oracle/oracle.go index e462ed54..b53095cc 100644 --- a/oracle/oracle.go +++ b/oracle/oracle.go @@ -2,6 +2,12 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. +// Package oracle contains the implementation of the oracle tool whose +// command-line is provided by code.google.com/p/go.tools/cmd/oracle. +// +// http://golang.org/s/oracle-design +// http://golang.org/s/oracle-user-manual +// package oracle // This file defines oracle.Query, the entry point for the oracle tool. @@ -33,16 +39,12 @@ import ( "code.google.com/p/go.tools/ssa" ) -type oracle struct { +// An Oracle holds the program state required for one or more queries. +type Oracle struct { out io.Writer // standard output prog *ssa.Program // the SSA program [only populated if need&SSA] config pointer.Config // pointer analysis configuration - // need&(Pos|ExactPos): - startPos, endPos token.Pos // source extent of query - queryPkgInfo *importer.PackageInfo // type info for the queried package - queryPath []ast.Node // AST path from query node to root of ast.File - // need&AllTypeInfo typeInfo map[*types.Package]*importer.PackageInfo // type info for all ASTs in the program @@ -53,30 +55,42 @@ type oracle struct { // // Typed ASTs for the whole program are always constructed // transiently; they are retained only for the queried package unless -// AllTypeInfo is set. +// needAllTypeInfo is set. const ( - Pos = 1 << iota // needs a position - ExactPos // needs an exact AST selection; implies Pos - AllTypeInfo // needs to retain type info for all ASTs in the program - SSA // needs ssa.Packages for whole program - PTA = SSA // needs pointer analysis + needPos = 1 << iota // needs a position + needExactPos // needs an exact AST selection; implies needPos + needAllTypeInfo // needs to retain type info for all ASTs in the program + needSSA // needs ssa.Packages for whole program + needSSADebug // needs debug info for ssa.Packages + needPTA = needSSA // needs pointer analysis + needAll = -1 // needs everything (e.g. a sequence of queries) ) type modeInfo struct { + name string needs int - impl func(*oracle) (queryResult, error) + impl func(*Oracle, *QueryPos) (queryResult, error) } -var modes = map[string]modeInfo{ - "callees": modeInfo{PTA | ExactPos, callees}, - "callers": modeInfo{PTA | Pos, callers}, - "callgraph": modeInfo{PTA, callgraph}, - "callstack": modeInfo{PTA | Pos, callstack}, - "describe": modeInfo{PTA | ExactPos, describe}, - "freevars": modeInfo{Pos, freevars}, - "implements": modeInfo{Pos, implements}, - "peers": modeInfo{PTA | Pos, peers}, - "referrers": modeInfo{AllTypeInfo | Pos, referrers}, +var modes = []*modeInfo{ + {"callees", needPTA | needExactPos, callees}, + {"callers", needPTA | needPos, callers}, + {"callgraph", needPTA, callgraph}, + {"callstack", needPTA | needPos, callstack}, + {"describe", needPTA | needSSADebug | needExactPos, describe}, + {"freevars", needPos, freevars}, + {"implements", needPos, implements}, + {"peers", needPTA | needSSADebug | needPos, peers}, + {"referrers", needAllTypeInfo | needPos, referrers}, +} + +func findMode(mode string) *modeInfo { + for _, m := range modes { + if m.name == mode { + return m + } + } + return nil } type printfFunc func(pos interface{}, format string, args ...interface{}) @@ -93,6 +107,17 @@ type warning struct { args []interface{} } +// A QueryPos represents the position provided as input to a query: +// a textual extent in the program's source code, the AST node it +// corresponds to, and the package to which it belongs. +// Instances are created by ParseQueryPos. +// +type QueryPos struct { + start, end token.Pos // source extent of query + info *importer.PackageInfo // type info for the queried package + path []ast.Node // AST path from query node to root of ast.File +} + // A Result encapsulates the result of an oracle.Query. // // Result instances implement the json.Marshaler interface, i.e. they @@ -118,41 +143,91 @@ func (res *Result) MarshalJSON() ([]byte, error) { return encjson.Marshal(resj) } -// Query runs the oracle. +// Query runs a single oracle query. +// // args specify the main package in importer.CreatePackageFromArgs syntax. // mode is the query mode ("callers", etc). -// pos is the selection in parseQueryPos() syntax. // ptalog is the (optional) pointer-analysis log file. -// buildContext is the optional configuration for locating packages. +// buildContext is the go/build configuration for locating packages. +// +// Clients that intend to perform multiple queries against the same +// analysis scope should use this pattern instead: +// +// imp := importer.New(&importer.Config{Build: buildContext}) +// o, err := oracle.New(imp, args, nil) +// if err != nil { ... } +// for ... { +// qpos, err := oracle.ParseQueryPos(imp, pos, needExact) +// if err != nil { ... } +// +// res, err := o.Query(mode, qpos) +// if err != nil { ... } +// +// // use res +// } +// +// TODO(adonovan): the ideal 'needsExact' parameter for ParseQueryPos +// depends on the query mode; how should we expose this? // func Query(args []string, mode, pos string, ptalog io.Writer, buildContext *build.Context) (*Result, error) { - minfo, ok := modes[mode] - if !ok { + minfo := findMode(mode) + if minfo == nil { return nil, fmt.Errorf("invalid mode type: %q", mode) } imp := importer.New(&importer.Config{Build: buildContext}) - o := &oracle{ + o, err := New(imp, args, ptalog) + if err != nil { + return nil, err + } + + // Phase timing diagnostics. + // TODO(adonovan): needs more work. + // if false { + // defer func() { + // fmt.Println() + // for name, duration := range o.timers { + // fmt.Printf("# %-30s %s\n", name, duration) + // } + // }() + // } + + var qpos *QueryPos + if minfo.needs&(needPos|needExactPos) != 0 { + var err error + qpos, err = ParseQueryPos(imp, pos, minfo.needs&needExactPos != 0) + if err != nil { + return nil, err + } + } + + // SSA is built and we have the QueryPos. + // Release the other ASTs and type info to the GC. + imp = nil + + return o.query(minfo, qpos) +} + +// New constructs a new Oracle that can be used for a sequence of queries. +// +// imp will be used to load source code for imported packages. +// It must not yet have loaded any packages. +// +// args specify the main package in importer.CreatePackageFromArgs syntax. +// +// ptalog is the (optional) pointer-analysis log file. +// +func New(imp *importer.Importer, args []string, ptalog io.Writer) (*Oracle, error) { + return newOracle(imp, args, ptalog, needAll) +} + +func newOracle(imp *importer.Importer, args []string, ptalog io.Writer, needs int) (*Oracle, error) { + o := &Oracle{ prog: ssa.NewProgram(imp.Fset, 0), timers: make(map[string]time.Duration), } o.config.Log = ptalog - var res Result - o.config.Warn = func(pos token.Pos, format string, args ...interface{}) { - res.warnings = append(res.warnings, warning{pos, format, args}) - } - - // Phase timing diagnostics. - if false { - defer func() { - fmt.Println() - for name, duration := range o.timers { - fmt.Printf("# %-30s %s\n", name, duration) - } - }() - } - // Load/parse/type-check program from args. start := time.Now() initialPkgInfos, args, err := imp.LoadInitialPackages(args) @@ -165,7 +240,7 @@ func Query(args []string, mode, pos string, ptalog io.Writer, buildContext *buil o.timers["load/parse/type"] = time.Since(start) // Retain type info for all ASTs in the program. - if minfo.needs&AllTypeInfo != 0 { + if needs&needAllTypeInfo != 0 { m := make(map[*types.Package]*importer.PackageInfo) for _, p := range imp.AllPackages() { m[p.Pkg] = p @@ -173,32 +248,13 @@ func Query(args []string, mode, pos string, ptalog io.Writer, buildContext *buil o.typeInfo = m } - // Parse the source query position. - if minfo.needs&(Pos|ExactPos) != 0 { - var err error - o.startPos, o.endPos, err = parseQueryPos(o.prog.Fset, pos) - if err != nil { - return nil, err - } - - var exact bool - o.queryPkgInfo, o.queryPath, exact = imp.PathEnclosingInterval(o.startPos, o.endPos) - if o.queryPath == nil { - return nil, o.errorf(false, "no syntax here") - } - if minfo.needs&ExactPos != 0 && !exact { - return nil, o.errorf(o.queryPath[0], "ambiguous selection within %s", - importer.NodeDescription(o.queryPath[0])) - } - } - // Create SSA package for the initial package and its dependencies. - if minfo.needs&SSA != 0 { + if needs&needSSA != 0 { start = time.Now() // Create SSA packages. if err := o.prog.CreatePackages(imp); err != nil { - return nil, o.errorf(false, "%s", err) + return nil, o.errorf(nil, "%s", err) } // Initial packages (specified on command line) @@ -211,34 +267,66 @@ func Query(args []string, mode, pos string, ptalog io.Writer, buildContext *buil // should build a single synthetic testmain package, // not synthetic main functions to many packages. if initialPkg.CreateTestMainFunction() == nil { - return nil, o.errorf(false, "analysis scope has no main() entry points") + return nil, o.errorf(nil, "analysis scope has no main() entry points") } } o.config.Mains = append(o.config.Mains, initialPkg) } - // Query package. - if o.queryPkgInfo != nil { - pkg := o.prog.Package(o.queryPkgInfo.Pkg) - pkg.SetDebugMode(true) - pkg.Build() + if needs&needSSADebug != 0 { + for _, pkg := range o.prog.AllPackages() { + pkg.SetDebugMode(true) + } } o.timers["SSA-create"] = time.Since(start) } - // SSA is built and we have query{Path,PkgInfo}. - // Release the other ASTs and type info to the GC. - imp = nil + return o, nil +} - res.q, err = minfo.impl(o) +// Query runs the query of the specified mode and selection. +func (o *Oracle) Query(mode string, qpos *QueryPos) (*Result, error) { + minfo := findMode(mode) + if minfo == nil { + return nil, fmt.Errorf("invalid mode type: %q", mode) + } + return o.query(minfo, qpos) +} + +func (o *Oracle) query(minfo *modeInfo, qpos *QueryPos) (*Result, error) { + res := &Result{ + mode: minfo.name, + fset: o.prog.Fset, + fprintf: o.fprintf, // captures o.prog, o.{start,end}Pos for later printing + } + o.config.Warn = func(pos token.Pos, format string, args ...interface{}) { + res.warnings = append(res.warnings, warning{pos, format, args}) + } + var err error + res.q, err = minfo.impl(o, qpos) if err != nil { return nil, err } - res.mode = mode - res.fset = o.prog.Fset - res.fprintf = o.fprintf // captures o.prog, o.{start,end}Pos for later printing - return &res, nil + return res, nil +} + +// ParseQueryPos parses the source query position pos. +// If needExact, it must identify a single AST subtree. +// +func ParseQueryPos(imp *importer.Importer, pos string, needExact bool) (*QueryPos, error) { + start, end, err := parseQueryPos(imp.Fset, pos) + if err != nil { + return nil, err + } + info, path, exact := imp.PathEnclosingInterval(start, end) + if path == nil { + return nil, errors.New("no syntax here") + } + if needExact && !exact { + return nil, fmt.Errorf("ambiguous selection within %s", importer.NodeDescription(path[0])) + } + return &QueryPos{start, end, info, path}, nil } // WriteTo writes the oracle query result res to out in a compiler diagnostic format. @@ -262,7 +350,7 @@ func (res *Result) WriteTo(out io.Writer) { // buildSSA constructs the SSA representation of Go-source function bodies. // Not needed in simpler modes, e.g. freevars. // -func buildSSA(o *oracle) { +func buildSSA(o *Oracle) { start := time.Now() o.prog.BuildAll() o.timers["SSA-build"] = time.Since(start) @@ -271,7 +359,7 @@ func buildSSA(o *oracle) { // ptrAnalysis runs the pointer analysis and returns the synthetic // root of the callgraph. // -func ptrAnalysis(o *oracle) pointer.CallGraphNode { +func ptrAnalysis(o *Oracle) pointer.CallGraphNode { start := time.Now() root := pointer.Analyze(&o.config) o.timers["pointer analysis"] = time.Since(start) @@ -399,15 +487,14 @@ func deref(typ types.Type) types.Type { // - an ast.Node, denoting an interval // - anything with a Pos() method: // ssa.Member, ssa.Value, ssa.Instruction, types.Object, pointer.Label, etc. -// - a bool, meaning the extent [o.startPos, o.endPos) of the user's query. -// (the value is ignored) +// - a QueryPos, denoting the extent of the user's query. // - nil, meaning no position at all. // // The output format is is compatible with the 'gnu' // compilation-error-regexp in Emacs' compilation mode. // TODO(adonovan): support other editors. // -func (o *oracle) fprintf(w io.Writer, pos interface{}, format string, args ...interface{}) { +func (o *Oracle) fprintf(w io.Writer, pos interface{}, format string, args ...interface{}) { var start, end token.Pos switch pos := pos.(type) { case ast.Node: @@ -421,9 +508,9 @@ func (o *oracle) fprintf(w io.Writer, pos interface{}, format string, args ...in }: start = pos.Pos() end = start - case bool: - start = o.startPos - end = o.endPos + case *QueryPos: + start = pos.start + end = pos.end case nil: // no-op default: @@ -448,14 +535,14 @@ func (o *oracle) fprintf(w io.Writer, pos interface{}, format string, args ...in } // errorf is like fprintf, but returns a formatted error string. -func (o *oracle) errorf(pos interface{}, format string, args ...interface{}) error { +func (o *Oracle) errorf(pos interface{}, format string, args ...interface{}) error { var buf bytes.Buffer o.fprintf(&buf, pos, format, args...) return errors.New(buf.String()) } // printNode returns the pretty-printed syntax of n. -func (o *oracle) printNode(n ast.Node) string { +func (o *Oracle) printNode(n ast.Node) string { var buf bytes.Buffer printer.Fprint(&buf, o.prog.Fset, n) return buf.String() diff --git a/oracle/oracle_test.go b/oracle/oracle_test.go index 12b9c687..097ecf4c 100644 --- a/oracle/oracle_test.go +++ b/oracle/oracle_test.go @@ -44,6 +44,7 @@ import ( "strings" "testing" + "code.google.com/p/go.tools/importer" "code.google.com/p/go.tools/oracle" ) @@ -249,3 +250,53 @@ func TestOracle(t *testing.T) { } } } + +func TestMultipleQueries(t *testing.T) { + // Importer + var buildContext = build.Default + buildContext.GOPATH = "testdata" + imp := importer.New(&importer.Config{Build: &buildContext}) + + // Oracle + filename := "testdata/src/main/multi.go" + o, err := oracle.New(imp, []string{filename}, nil) + if err != nil { + t.Fatalf("oracle.New failed: %s", err) + } + + // QueryPos + pos := filename + ":#54,#58" + qpos, err := oracle.ParseQueryPos(imp, pos, true) + if err != nil { + t.Fatalf("oracle.ParseQueryPos(%q) failed: %s", pos, err) + } + // SSA is built and we have the QueryPos. + // Release the other ASTs and type info to the GC. + imp = nil + + // Run different query moes on same scope and selection. + out := new(bytes.Buffer) + for _, mode := range [...]string{"callers", "describe", "freevars"} { + res, err := o.Query(mode, qpos) + if err != nil { + t.Errorf("(*oracle.Oracle).Query(%q) failed: %s", pos, err) + } + capture := new(bytes.Buffer) // capture standard output + res.WriteTo(capture) + for _, line := range strings.Split(capture.String(), "\n") { + fmt.Fprintf(out, "%s\n", stripLocation(line)) + } + } + want := `multi.f is called from these 1 sites: + static function call from multi.main + +function call (or conversion) of type () + +Free identifiers: +var x int + +` + if got := out.String(); got != want { + t.Errorf("Query output differs; want <<%s>>, got <<%s>>\n", want, got) + } +} diff --git a/oracle/peers.go b/oracle/peers.go index 23dd1bef..09dba85d 100644 --- a/oracle/peers.go +++ b/oracle/peers.go @@ -22,10 +22,10 @@ import ( // TODO(adonovan): permit the user to query based on a MakeChan (not send/recv), // or the implicit receive in "for v := range ch". // -func peers(o *oracle) (queryResult, error) { - arrowPos := findArrow(o) +func peers(o *Oracle, qpos *QueryPos) (queryResult, error) { + arrowPos := findArrow(qpos) if arrowPos == token.NoPos { - return nil, o.errorf(o.queryPath[0], "there is no send/receive here") + return nil, o.errorf(qpos.path[0], "there is no send/receive here") } buildSSA(o) @@ -111,8 +111,8 @@ func peers(o *oracle) (queryResult, error) { // findArrow returns the position of the enclosing send/receive op // (<-) for the query position, or token.NoPos if not found. // -func findArrow(o *oracle) token.Pos { - for _, n := range o.queryPath { +func findArrow(qpos *QueryPos) token.Pos { + for _, n := range qpos.path { switch n := n.(type) { case *ast.UnaryExpr: if n.Op == token.ARROW { diff --git a/oracle/referrers.go b/oracle/referrers.go index ab369d07..4bf16bd8 100644 --- a/oracle/referrers.go +++ b/oracle/referrers.go @@ -16,16 +16,16 @@ import ( // Referrers reports all identifiers that resolve to the same object // as the queried identifier, within any package in the analysis scope. // -func referrers(o *oracle) (queryResult, error) { - id, _ := o.queryPath[0].(*ast.Ident) +func referrers(o *Oracle, qpos *QueryPos) (queryResult, error) { + id, _ := qpos.path[0].(*ast.Ident) if id == nil { - return nil, o.errorf(false, "no identifier here") + return nil, o.errorf(qpos, "no identifier here") } - obj := o.queryPkgInfo.ObjectOf(id) + obj := qpos.info.ObjectOf(id) if obj == nil { // Happens for y in "switch y := x.(type)", but I think that's all. - return nil, o.errorf(false, "no object for identifier") + return nil, o.errorf(qpos, "no object for identifier") } // Iterate over all go/types' resolver facts for the entire program. diff --git a/oracle/testdata/src/main/calls.golden b/oracle/testdata/src/main/calls.golden index e0312884..225570a7 100644 --- a/oracle/testdata/src/main/calls.golden +++ b/oracle/testdata/src/main/calls.golden @@ -80,7 +80,6 @@ Error: this is a type conversion, not a function call -------- @callees callees-err-bad-selection -------- Error: ambiguous selection within function call (or conversion) - -------- @callees callees-err-deadcode1 -------- Error: this call site is unreachable in this analysis diff --git a/oracle/testdata/src/main/freevars.golden b/oracle/testdata/src/main/freevars.golden index 4b29fa45..a4c2c77c 100644 --- a/oracle/testdata/src/main/freevars.golden +++ b/oracle/testdata/src/main/freevars.golden @@ -1,11 +1,11 @@ -------- @freevars fv1 -------- -Free identifers: +Free identifiers: type C main.C const exp int var x int -------- @freevars fv2 -------- -Free identifers: +Free identifiers: var s.t.a int var s.t.b int var s.x int @@ -13,6 +13,6 @@ var x int var y int32 -------- @freevars fv3 -------- -Free identifers: +Free identifiers: var x int diff --git a/oracle/testdata/src/main/multi.go b/oracle/testdata/src/main/multi.go new file mode 100644 index 00000000..54caf15d --- /dev/null +++ b/oracle/testdata/src/main/multi.go @@ -0,0 +1,13 @@ +package multi + +func g(x int) { +} + +func f() { + x := 1 + g(x) // "g(x)" is the selection for multiple queries +} + +func main() { + f() +}