diff --git a/oracle/callees.go b/oracle/callees.go index afd8d83f..1e0d65e1 100644 --- a/oracle/callees.go +++ b/oracle/callees.go @@ -20,28 +20,28 @@ import ( // // TODO(adonovan): if a callee is a wrapper, show the callee's callee. // -func callees(o *oracle) (queryResult, error) { +func callees(o *Oracle, qpos *QueryPos) (queryResult, error) { // Determine the enclosing call for the specified position. var call *ast.CallExpr - for _, n := range o.queryPath { + for _, n := range qpos.path { if call, _ = n.(*ast.CallExpr); call != nil { break } } if call == nil { - return nil, o.errorf(o.queryPath[0], "there is no function call here") + return nil, o.errorf(qpos.path[0], "there is no function call here") } // TODO(adonovan): issue an error if the call is "too far // away" from the current selection, as this most likely is // not what the user intended. // Reject type conversions. - if o.queryPkgInfo.IsType(call.Fun) { + if qpos.info.IsType(call.Fun) { return nil, o.errorf(call, "this is a type conversion, not a function call") } // Reject calls to built-ins. - if b, ok := o.queryPkgInfo.TypeOf(call.Fun).(*types.Builtin); ok { + if b, ok := qpos.info.TypeOf(call.Fun).(*types.Builtin); ok { return nil, o.errorf(call, "this is a call to the built-in '%s' operator", b.Name()) } diff --git a/oracle/callers.go b/oracle/callers.go index 37549978..01f827be 100644 --- a/oracle/callers.go +++ b/oracle/callers.go @@ -17,20 +17,20 @@ import ( // // TODO(adonovan): if a caller is a wrapper, show the caller's caller. // -func callers(o *oracle) (queryResult, error) { - pkg := o.prog.Package(o.queryPkgInfo.Pkg) +func callers(o *Oracle, qpos *QueryPos) (queryResult, error) { + pkg := o.prog.Package(qpos.info.Pkg) if pkg == nil { - return nil, o.errorf(o.queryPath[0], "no SSA package") + return nil, o.errorf(qpos.path[0], "no SSA package") } - if !ssa.HasEnclosingFunction(pkg, o.queryPath) { - return nil, o.errorf(o.queryPath[0], "this position is not inside a function") + if !ssa.HasEnclosingFunction(pkg, qpos.path) { + return nil, o.errorf(qpos.path[0], "this position is not inside a function") } buildSSA(o) - target := ssa.EnclosingFunction(pkg, o.queryPath) + target := ssa.EnclosingFunction(pkg, qpos.path) if target == nil { - return nil, o.errorf(o.queryPath[0], "no SSA function built for this location (dead code?)") + return nil, o.errorf(qpos.path[0], "no SSA function built for this location (dead code?)") } // Run the pointer analysis, recording each diff --git a/oracle/callgraph.go b/oracle/callgraph.go index 6ebe1912..c8cdd680 100644 --- a/oracle/callgraph.go +++ b/oracle/callgraph.go @@ -25,7 +25,7 @@ import ( // // TODO(adonovan): elide nodes for synthetic functions? // -func callgraph(o *oracle) (queryResult, error) { +func callgraph(o *Oracle, _ *QueryPos) (queryResult, error) { buildSSA(o) // Run the pointer analysis and build the complete callgraph. @@ -35,6 +35,7 @@ func callgraph(o *oracle) (queryResult, error) { // Assign (preorder) numbers to all the callgraph nodes. // TODO(adonovan): the callgraph API should do this for us. + // (Actually, it does have unique numbers under the hood.) numbering := make(map[pointer.CallGraphNode]int) var number func(cgn pointer.CallGraphNode) number = func(cgn pointer.CallGraphNode) { @@ -70,6 +71,8 @@ Some nodes may appear multiple times due to context-sensitive treatment of some calls. `) + // TODO(adonovan): compute the numbers as we print; right now + // it depends on map iteration so it's arbitrary,which is ugly. seen := make(map[pointer.CallGraphNode]bool) var print func(cgn pointer.CallGraphNode, indent int) print = func(cgn pointer.CallGraphNode, indent int) { diff --git a/oracle/callstack.go b/oracle/callstack.go index 16516b6d..6a5f4323 100644 --- a/oracle/callstack.go +++ b/oracle/callstack.go @@ -22,21 +22,21 @@ import ( // TODO(adonovan): permit user to specify a starting point other than // the analysis root. // -func callstack(o *oracle) (queryResult, error) { - pkg := o.prog.Package(o.queryPkgInfo.Pkg) +func callstack(o *Oracle, qpos *QueryPos) (queryResult, error) { + pkg := o.prog.Package(qpos.info.Pkg) if pkg == nil { - return nil, o.errorf(o.queryPath[0], "no SSA package") + return nil, o.errorf(qpos.path[0], "no SSA package") } - if !ssa.HasEnclosingFunction(pkg, o.queryPath) { - return nil, o.errorf(o.queryPath[0], "this position is not inside a function") + if !ssa.HasEnclosingFunction(pkg, qpos.path) { + return nil, o.errorf(qpos.path[0], "this position is not inside a function") } buildSSA(o) - target := ssa.EnclosingFunction(pkg, o.queryPath) + target := ssa.EnclosingFunction(pkg, qpos.path) if target == nil { - return nil, o.errorf(o.queryPath[0], + return nil, o.errorf(qpos.path[0], "no SSA function built for this location (dead code?)") } @@ -74,19 +74,21 @@ func callstack(o *oracle) (queryResult, error) { } return &callstackResult{ + qpos: qpos, target: target, callstack: callstack, }, nil } type callstackResult struct { + qpos *QueryPos target *ssa.Function callstack []pointer.CallSite } func (r *callstackResult) display(printf printfFunc) { if r.callstack != nil { - printf(false, "Found a call path from root to %s", r.target) + printf(r.qpos, "Found a call path from root to %s", r.target) printf(r.target, "%s", r.target) for _, site := range r.callstack { printf(site, "%s from %s", site.Description(), site.Caller().Func()) diff --git a/oracle/describe.go b/oracle/describe.go index 143c7e4c..2c30823c 100644 --- a/oracle/describe.go +++ b/oracle/describe.go @@ -33,25 +33,25 @@ import ( // // All printed sets are sorted to ensure determinism. // -func describe(o *oracle) (queryResult, error) { +func describe(o *Oracle, qpos *QueryPos) (queryResult, error) { if false { // debugging - o.fprintf(os.Stderr, o.queryPath[0], "you selected: %s %s", - importer.NodeDescription(o.queryPath[0]), pathToString2(o.queryPath)) + o.fprintf(os.Stderr, qpos.path[0], "you selected: %s %s", + importer.NodeDescription(qpos.path[0]), pathToString2(qpos.path)) } - path, action := findInterestingNode(o.queryPkgInfo, o.queryPath) + path, action := findInterestingNode(qpos.info, qpos.path) switch action { case actionExpr: - return describeValue(o, path) + return describeValue(o, qpos, path) case actionType: - return describeType(o, path) + return describeType(o, qpos, path) case actionPackage: - return describePackage(o, path) + return describePackage(o, qpos, path) case actionStmt: - return describeStmt(o, path) + return describeStmt(o, qpos, path) case actionUnknown: return &describeUnknownResult{path[0]}, nil @@ -305,11 +305,11 @@ func findInterestingNode(pkginfo *importer.PackageInfo, path []ast.Node) ([]ast. // to the root of the AST is path. It may return a nil Value without // an error to indicate the pointer analysis is not appropriate. // -func ssaValueForIdent(o *oracle, obj types.Object, path []ast.Node) (ssa.Value, error) { +func ssaValueForIdent(prog *ssa.Program, qinfo *importer.PackageInfo, obj types.Object, path []ast.Node) (ssa.Value, error) { if obj, ok := obj.(*types.Var); ok { - pkg := o.prog.Package(o.queryPkgInfo.Pkg) + pkg := prog.Package(qinfo.Pkg) pkg.Build() - if v := o.prog.VarValue(obj, pkg, path); v != nil { + if v := prog.VarValue(obj, pkg, path); v != nil { // Don't run pointer analysis on a ref to a const expression. if _, ok := v.(*ssa.Const); ok { v = nil @@ -328,8 +328,8 @@ func ssaValueForIdent(o *oracle, obj types.Object, path []ast.Node) (ssa.Value, // return a nil Value without an error to indicate the pointer // analysis is not appropriate. // -func ssaValueForExpr(o *oracle, path []ast.Node) (ssa.Value, error) { - pkg := o.prog.Package(o.queryPkgInfo.Pkg) +func ssaValueForExpr(prog *ssa.Program, qinfo *importer.PackageInfo, path []ast.Node) (ssa.Value, error) { + pkg := prog.Package(qinfo.Pkg) pkg.SetDebugMode(true) pkg.Build() @@ -345,7 +345,7 @@ func ssaValueForExpr(o *oracle, path []ast.Node) (ssa.Value, error) { return nil, fmt.Errorf("can't locate SSA Value for expression in %s", fn) } -func describeValue(o *oracle, path []ast.Node) (*describeValueResult, error) { +func describeValue(o *Oracle, qpos *QueryPos, path []ast.Node) (*describeValueResult, error) { var expr ast.Expr var obj types.Object switch n := path[0].(type) { @@ -353,7 +353,7 @@ func describeValue(o *oracle, path []ast.Node) (*describeValueResult, error) { // ambiguous ValueSpec containing multiple names return nil, o.errorf(n, "multiple value specification") case *ast.Ident: - obj = o.queryPkgInfo.ObjectOf(n) + obj = qpos.info.ObjectOf(n) expr = n case ast.Expr: expr = n @@ -362,8 +362,8 @@ func describeValue(o *oracle, path []ast.Node) (*describeValueResult, error) { return nil, o.errorf(n, "unexpected AST for expr: %T", n) } - typ := o.queryPkgInfo.TypeOf(expr) - constVal := o.queryPkgInfo.ValueOf(expr) + typ := qpos.info.TypeOf(expr) + constVal := qpos.info.ValueOf(expr) // From this point on, we cannot fail with an error. // Failure to run the pointer analysis will be reported later. @@ -386,11 +386,11 @@ func describeValue(o *oracle, path []ast.Node) (*describeValueResult, error) { var value ssa.Value if obj != nil { // def/ref of func/var/const object - value, ptaErr = ssaValueForIdent(o, obj, path) + value, ptaErr = ssaValueForIdent(o.prog, qpos.info, obj, path) } else { // any other expression - if o.queryPkgInfo.ValueOf(path[0].(ast.Expr)) == nil { // non-constant? - value, ptaErr = ssaValueForExpr(o, path) + if qpos.info.ValueOf(path[0].(ast.Expr)) == nil { // non-constant? + value, ptaErr = ssaValueForExpr(o.prog, qpos.info, path) } } if value != nil { @@ -404,6 +404,7 @@ func describeValue(o *oracle, path []ast.Node) (*describeValueResult, error) { } return &describeValueResult{ + qpos: qpos, expr: expr, typ: typ, constVal: constVal, @@ -414,7 +415,7 @@ func describeValue(o *oracle, path []ast.Node) (*describeValueResult, error) { } // describePointer runs the pointer analysis of the selected SSA value. -func describePointer(o *oracle, v ssa.Value, indirect bool) (ptrs []pointerResult, err error) { +func describePointer(o *Oracle, v ssa.Value, indirect bool) (ptrs []pointerResult, err error) { buildSSA(o) // TODO(adonovan): don't run indirect pointer analysis on non-ptr-ptrlike types. @@ -454,6 +455,7 @@ type pointerResult struct { } type describeValueResult struct { + qpos *QueryPos expr ast.Expr // query node typ types.Type // type of expression constVal exact.Value // value of expression, if constant @@ -524,7 +526,7 @@ func (r *describeValueResult) display(printf printfFunc) { // reflect.Value expression. if len(r.ptrs) > 0 { - printf(false, "this %s may contain these dynamic types:", r.typ) + printf(r.qpos, "this %s may contain these dynamic types:", r.typ) for _, ptr := range r.ptrs { var obj types.Object if nt, ok := deref(ptr.typ).(*types.Named); ok { @@ -538,15 +540,15 @@ func (r *describeValueResult) display(printf printfFunc) { } } } else { - printf(false, "this %s cannot contain any dynamic types.", r.typ) + printf(r.qpos, "this %s cannot contain any dynamic types.", r.typ) } } else { // Show labels for other expressions. if ptr := r.ptrs[0]; len(ptr.labels) > 0 { - printf(false, "value may point to these labels:") + printf(r.qpos, "value may point to these labels:") printLabels(printf, ptr.labels, "\t") } else { - printf(false, "value cannot point to anything.") + printf(r.qpos, "value cannot point to anything.") } } } @@ -622,12 +624,12 @@ func printLabels(printf printfFunc, labels []*pointer.Label, prefix string) { // ---- TYPE ------------------------------------------------------------ -func describeType(o *oracle, path []ast.Node) (*describeTypeResult, error) { +func describeType(o *Oracle, qpos *QueryPos, path []ast.Node) (*describeTypeResult, error) { var description string var t types.Type switch n := path[0].(type) { case *ast.Ident: - t = o.queryPkgInfo.TypeOf(n) + t = qpos.info.TypeOf(n) switch t := t.(type) { case *types.Basic: description = "reference to built-in type " + t.String() @@ -642,7 +644,7 @@ func describeType(o *oracle, path []ast.Node) (*describeTypeResult, error) { } case ast.Expr: - t = o.queryPkgInfo.TypeOf(n) + t = qpos.info.TypeOf(n) description = "type " + t.String() default: @@ -654,7 +656,7 @@ func describeType(o *oracle, path []ast.Node) (*describeTypeResult, error) { node: path[0], description: description, typ: t, - methods: accessibleMethods(t, o.queryPkgInfo.Pkg), + methods: accessibleMethods(t, qpos.info.Pkg), }, nil } @@ -708,7 +710,7 @@ func (r *describeTypeResult) toJSON(res *json.Result, fset *token.FileSet) { // ---- PACKAGE ------------------------------------------------------------ -func describePackage(o *oracle, path []ast.Node) (*describePackageResult, error) { +func describePackage(o *Oracle, qpos *QueryPos, path []ast.Node) (*describePackageResult, error) { var description string var pkg *types.Package switch n := path[0].(type) { @@ -724,12 +726,12 @@ func describePackage(o *oracle, path []ast.Node) (*describePackageResult, error) case *ast.Ident: if _, isDef := path[1].(*ast.File); isDef { // e.g. package id - pkg = o.queryPkgInfo.Pkg + pkg = qpos.info.Pkg description = fmt.Sprintf("definition of package %q", pkg.Path()) } else { // e.g. import id // or id.F() - pkg = o.queryPkgInfo.ObjectOf(n).Pkg() + pkg = qpos.info.ObjectOf(n).Pkg() description = fmt.Sprintf("reference to package %q", pkg.Path()) } @@ -744,11 +746,11 @@ func describePackage(o *oracle, path []ast.Node) (*describePackageResult, error) // Enumerate the accessible package members // in lexicographic order. for _, name := range pkg.Scope().Names() { - if pkg == o.queryPkgInfo.Pkg || ast.IsExported(name) { + if pkg == qpos.info.Pkg || ast.IsExported(name) { mem := pkg.Scope().Lookup(name) var methods []*types.Selection if mem, ok := mem.(*types.TypeName); ok { - methods = accessibleMethods(mem.Type(), o.queryPkgInfo.Pkg) + methods = accessibleMethods(mem.Type(), qpos.info.Pkg) } members = append(members, &describeMember{ mem, @@ -878,11 +880,11 @@ func tokenOf(o types.Object) string { // ---- STATEMENT ------------------------------------------------------------ -func describeStmt(o *oracle, path []ast.Node) (*describeStmtResult, error) { +func describeStmt(o *Oracle, qpos *QueryPos, path []ast.Node) (*describeStmtResult, error) { var description string switch n := path[0].(type) { case *ast.Ident: - if o.queryPkgInfo.ObjectOf(n).Pos() == n.Pos() { + if qpos.info.ObjectOf(n).Pos() == n.Pos() { description = "labelled statement" } else { description = "reference to labelled statement" diff --git a/oracle/freevars.go b/oracle/freevars.go index 07973f2a..5e5d7548 100644 --- a/oracle/freevars.go +++ b/oracle/freevars.go @@ -26,9 +26,9 @@ import ( // these might be interesting. Perhaps group the results into three // bands. // -func freevars(o *oracle) (queryResult, error) { - file := o.queryPath[len(o.queryPath)-1] // the enclosing file - fileScope := o.queryPkgInfo.Scopes[file] +func freevars(o *Oracle, qpos *QueryPos) (queryResult, error) { + file := qpos.path[len(qpos.path)-1] // the enclosing file + fileScope := qpos.info.Scopes[file] pkgScope := fileScope.Parent() // The id and sel functions return non-nil if they denote an @@ -49,7 +49,7 @@ func freevars(o *oracle) (queryResult, error) { } id = func(n *ast.Ident) types.Object { - obj := o.queryPkgInfo.ObjectOf(n) + obj := qpos.info.ObjectOf(n) if obj == nil { return nil // TODO(adonovan): fix: this fails for *types.Label. panic(o.errorf(n, "no types.Object for ast.Ident")) @@ -70,7 +70,7 @@ func freevars(o *oracle) (queryResult, error) { if scope == fileScope || scope == pkgScope { return nil // defined at file or package scope } - if o.startPos <= obj.Pos() && obj.Pos() <= o.endPos { + if qpos.start <= obj.Pos() && obj.Pos() <= qpos.end { return nil // defined within selection => not free } return obj @@ -82,7 +82,7 @@ func freevars(o *oracle) (queryResult, error) { refsMap := make(map[string]freevarsRef) // Visit all the identifiers in the selected ASTs. - ast.Inspect(o.queryPath[0], func(n ast.Node) bool { + ast.Inspect(qpos.path[0], func(n ast.Node) bool { if n == nil { return true // popping DFS stack } @@ -90,7 +90,7 @@ func freevars(o *oracle) (queryResult, error) { // Is this node contained within the selection? // (freevars permits inexact selections, // like two stmts in a block.) - if o.startPos <= n.Pos() && n.End() <= o.endPos { + if qpos.start <= n.Pos() && n.End() <= qpos.end { var obj types.Object var prune bool switch n := n.(type) { @@ -119,7 +119,7 @@ func freevars(o *oracle) (queryResult, error) { panic(obj) } - typ := o.queryPkgInfo.TypeOf(n.(ast.Expr)) + typ := qpos.info.TypeOf(n.(ast.Expr)) ref := freevarsRef{kind, o.printNode(n), typ, obj} refsMap[ref.ref] = ref @@ -139,12 +139,14 @@ func freevars(o *oracle) (queryResult, error) { sort.Sort(byRef(refs)) return &freevarsResult{ + qpos: qpos, fset: o.prog.Fset, refs: refs, }, nil } type freevarsResult struct { + qpos *QueryPos fset *token.FileSet refs []freevarsRef } @@ -158,9 +160,9 @@ type freevarsRef struct { func (r *freevarsResult) display(printf printfFunc) { if len(r.refs) == 0 { - printf(false, "No free identifers.") + printf(r.qpos, "No free identifiers.") } else { - printf(false, "Free identifers:") + printf(r.qpos, "Free identifiers:") for _, ref := range r.refs { printf(ref.obj, "%s %s %s", ref.kind, ref.ref, ref.typ) } diff --git a/oracle/implements.go b/oracle/implements.go index b31617a2..d26cd1ec 100644 --- a/oracle/implements.go +++ b/oracle/implements.go @@ -30,8 +30,8 @@ import ( // actually occur, with examples? (NB: this is not a conservative // answer due to ChangeInterface, i.e. subtyping among interfaces.) // -func implements(o *oracle) (queryResult, error) { - pkg := o.queryPkgInfo.Pkg +func implements(o *Oracle, qpos *QueryPos) (queryResult, error) { + pkg := qpos.info.Pkg // Compute set of named interface/concrete types at package level. var interfaces, concretes []*types.Named diff --git a/oracle/oracle.go b/oracle/oracle.go index e462ed54..b53095cc 100644 --- a/oracle/oracle.go +++ b/oracle/oracle.go @@ -2,6 +2,12 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. +// Package oracle contains the implementation of the oracle tool whose +// command-line is provided by code.google.com/p/go.tools/cmd/oracle. +// +// http://golang.org/s/oracle-design +// http://golang.org/s/oracle-user-manual +// package oracle // This file defines oracle.Query, the entry point for the oracle tool. @@ -33,16 +39,12 @@ import ( "code.google.com/p/go.tools/ssa" ) -type oracle struct { +// An Oracle holds the program state required for one or more queries. +type Oracle struct { out io.Writer // standard output prog *ssa.Program // the SSA program [only populated if need&SSA] config pointer.Config // pointer analysis configuration - // need&(Pos|ExactPos): - startPos, endPos token.Pos // source extent of query - queryPkgInfo *importer.PackageInfo // type info for the queried package - queryPath []ast.Node // AST path from query node to root of ast.File - // need&AllTypeInfo typeInfo map[*types.Package]*importer.PackageInfo // type info for all ASTs in the program @@ -53,30 +55,42 @@ type oracle struct { // // Typed ASTs for the whole program are always constructed // transiently; they are retained only for the queried package unless -// AllTypeInfo is set. +// needAllTypeInfo is set. const ( - Pos = 1 << iota // needs a position - ExactPos // needs an exact AST selection; implies Pos - AllTypeInfo // needs to retain type info for all ASTs in the program - SSA // needs ssa.Packages for whole program - PTA = SSA // needs pointer analysis + needPos = 1 << iota // needs a position + needExactPos // needs an exact AST selection; implies needPos + needAllTypeInfo // needs to retain type info for all ASTs in the program + needSSA // needs ssa.Packages for whole program + needSSADebug // needs debug info for ssa.Packages + needPTA = needSSA // needs pointer analysis + needAll = -1 // needs everything (e.g. a sequence of queries) ) type modeInfo struct { + name string needs int - impl func(*oracle) (queryResult, error) + impl func(*Oracle, *QueryPos) (queryResult, error) } -var modes = map[string]modeInfo{ - "callees": modeInfo{PTA | ExactPos, callees}, - "callers": modeInfo{PTA | Pos, callers}, - "callgraph": modeInfo{PTA, callgraph}, - "callstack": modeInfo{PTA | Pos, callstack}, - "describe": modeInfo{PTA | ExactPos, describe}, - "freevars": modeInfo{Pos, freevars}, - "implements": modeInfo{Pos, implements}, - "peers": modeInfo{PTA | Pos, peers}, - "referrers": modeInfo{AllTypeInfo | Pos, referrers}, +var modes = []*modeInfo{ + {"callees", needPTA | needExactPos, callees}, + {"callers", needPTA | needPos, callers}, + {"callgraph", needPTA, callgraph}, + {"callstack", needPTA | needPos, callstack}, + {"describe", needPTA | needSSADebug | needExactPos, describe}, + {"freevars", needPos, freevars}, + {"implements", needPos, implements}, + {"peers", needPTA | needSSADebug | needPos, peers}, + {"referrers", needAllTypeInfo | needPos, referrers}, +} + +func findMode(mode string) *modeInfo { + for _, m := range modes { + if m.name == mode { + return m + } + } + return nil } type printfFunc func(pos interface{}, format string, args ...interface{}) @@ -93,6 +107,17 @@ type warning struct { args []interface{} } +// A QueryPos represents the position provided as input to a query: +// a textual extent in the program's source code, the AST node it +// corresponds to, and the package to which it belongs. +// Instances are created by ParseQueryPos. +// +type QueryPos struct { + start, end token.Pos // source extent of query + info *importer.PackageInfo // type info for the queried package + path []ast.Node // AST path from query node to root of ast.File +} + // A Result encapsulates the result of an oracle.Query. // // Result instances implement the json.Marshaler interface, i.e. they @@ -118,41 +143,91 @@ func (res *Result) MarshalJSON() ([]byte, error) { return encjson.Marshal(resj) } -// Query runs the oracle. +// Query runs a single oracle query. +// // args specify the main package in importer.CreatePackageFromArgs syntax. // mode is the query mode ("callers", etc). -// pos is the selection in parseQueryPos() syntax. // ptalog is the (optional) pointer-analysis log file. -// buildContext is the optional configuration for locating packages. +// buildContext is the go/build configuration for locating packages. +// +// Clients that intend to perform multiple queries against the same +// analysis scope should use this pattern instead: +// +// imp := importer.New(&importer.Config{Build: buildContext}) +// o, err := oracle.New(imp, args, nil) +// if err != nil { ... } +// for ... { +// qpos, err := oracle.ParseQueryPos(imp, pos, needExact) +// if err != nil { ... } +// +// res, err := o.Query(mode, qpos) +// if err != nil { ... } +// +// // use res +// } +// +// TODO(adonovan): the ideal 'needsExact' parameter for ParseQueryPos +// depends on the query mode; how should we expose this? // func Query(args []string, mode, pos string, ptalog io.Writer, buildContext *build.Context) (*Result, error) { - minfo, ok := modes[mode] - if !ok { + minfo := findMode(mode) + if minfo == nil { return nil, fmt.Errorf("invalid mode type: %q", mode) } imp := importer.New(&importer.Config{Build: buildContext}) - o := &oracle{ + o, err := New(imp, args, ptalog) + if err != nil { + return nil, err + } + + // Phase timing diagnostics. + // TODO(adonovan): needs more work. + // if false { + // defer func() { + // fmt.Println() + // for name, duration := range o.timers { + // fmt.Printf("# %-30s %s\n", name, duration) + // } + // }() + // } + + var qpos *QueryPos + if minfo.needs&(needPos|needExactPos) != 0 { + var err error + qpos, err = ParseQueryPos(imp, pos, minfo.needs&needExactPos != 0) + if err != nil { + return nil, err + } + } + + // SSA is built and we have the QueryPos. + // Release the other ASTs and type info to the GC. + imp = nil + + return o.query(minfo, qpos) +} + +// New constructs a new Oracle that can be used for a sequence of queries. +// +// imp will be used to load source code for imported packages. +// It must not yet have loaded any packages. +// +// args specify the main package in importer.CreatePackageFromArgs syntax. +// +// ptalog is the (optional) pointer-analysis log file. +// +func New(imp *importer.Importer, args []string, ptalog io.Writer) (*Oracle, error) { + return newOracle(imp, args, ptalog, needAll) +} + +func newOracle(imp *importer.Importer, args []string, ptalog io.Writer, needs int) (*Oracle, error) { + o := &Oracle{ prog: ssa.NewProgram(imp.Fset, 0), timers: make(map[string]time.Duration), } o.config.Log = ptalog - var res Result - o.config.Warn = func(pos token.Pos, format string, args ...interface{}) { - res.warnings = append(res.warnings, warning{pos, format, args}) - } - - // Phase timing diagnostics. - if false { - defer func() { - fmt.Println() - for name, duration := range o.timers { - fmt.Printf("# %-30s %s\n", name, duration) - } - }() - } - // Load/parse/type-check program from args. start := time.Now() initialPkgInfos, args, err := imp.LoadInitialPackages(args) @@ -165,7 +240,7 @@ func Query(args []string, mode, pos string, ptalog io.Writer, buildContext *buil o.timers["load/parse/type"] = time.Since(start) // Retain type info for all ASTs in the program. - if minfo.needs&AllTypeInfo != 0 { + if needs&needAllTypeInfo != 0 { m := make(map[*types.Package]*importer.PackageInfo) for _, p := range imp.AllPackages() { m[p.Pkg] = p @@ -173,32 +248,13 @@ func Query(args []string, mode, pos string, ptalog io.Writer, buildContext *buil o.typeInfo = m } - // Parse the source query position. - if minfo.needs&(Pos|ExactPos) != 0 { - var err error - o.startPos, o.endPos, err = parseQueryPos(o.prog.Fset, pos) - if err != nil { - return nil, err - } - - var exact bool - o.queryPkgInfo, o.queryPath, exact = imp.PathEnclosingInterval(o.startPos, o.endPos) - if o.queryPath == nil { - return nil, o.errorf(false, "no syntax here") - } - if minfo.needs&ExactPos != 0 && !exact { - return nil, o.errorf(o.queryPath[0], "ambiguous selection within %s", - importer.NodeDescription(o.queryPath[0])) - } - } - // Create SSA package for the initial package and its dependencies. - if minfo.needs&SSA != 0 { + if needs&needSSA != 0 { start = time.Now() // Create SSA packages. if err := o.prog.CreatePackages(imp); err != nil { - return nil, o.errorf(false, "%s", err) + return nil, o.errorf(nil, "%s", err) } // Initial packages (specified on command line) @@ -211,34 +267,66 @@ func Query(args []string, mode, pos string, ptalog io.Writer, buildContext *buil // should build a single synthetic testmain package, // not synthetic main functions to many packages. if initialPkg.CreateTestMainFunction() == nil { - return nil, o.errorf(false, "analysis scope has no main() entry points") + return nil, o.errorf(nil, "analysis scope has no main() entry points") } } o.config.Mains = append(o.config.Mains, initialPkg) } - // Query package. - if o.queryPkgInfo != nil { - pkg := o.prog.Package(o.queryPkgInfo.Pkg) - pkg.SetDebugMode(true) - pkg.Build() + if needs&needSSADebug != 0 { + for _, pkg := range o.prog.AllPackages() { + pkg.SetDebugMode(true) + } } o.timers["SSA-create"] = time.Since(start) } - // SSA is built and we have query{Path,PkgInfo}. - // Release the other ASTs and type info to the GC. - imp = nil + return o, nil +} - res.q, err = minfo.impl(o) +// Query runs the query of the specified mode and selection. +func (o *Oracle) Query(mode string, qpos *QueryPos) (*Result, error) { + minfo := findMode(mode) + if minfo == nil { + return nil, fmt.Errorf("invalid mode type: %q", mode) + } + return o.query(minfo, qpos) +} + +func (o *Oracle) query(minfo *modeInfo, qpos *QueryPos) (*Result, error) { + res := &Result{ + mode: minfo.name, + fset: o.prog.Fset, + fprintf: o.fprintf, // captures o.prog, o.{start,end}Pos for later printing + } + o.config.Warn = func(pos token.Pos, format string, args ...interface{}) { + res.warnings = append(res.warnings, warning{pos, format, args}) + } + var err error + res.q, err = minfo.impl(o, qpos) if err != nil { return nil, err } - res.mode = mode - res.fset = o.prog.Fset - res.fprintf = o.fprintf // captures o.prog, o.{start,end}Pos for later printing - return &res, nil + return res, nil +} + +// ParseQueryPos parses the source query position pos. +// If needExact, it must identify a single AST subtree. +// +func ParseQueryPos(imp *importer.Importer, pos string, needExact bool) (*QueryPos, error) { + start, end, err := parseQueryPos(imp.Fset, pos) + if err != nil { + return nil, err + } + info, path, exact := imp.PathEnclosingInterval(start, end) + if path == nil { + return nil, errors.New("no syntax here") + } + if needExact && !exact { + return nil, fmt.Errorf("ambiguous selection within %s", importer.NodeDescription(path[0])) + } + return &QueryPos{start, end, info, path}, nil } // WriteTo writes the oracle query result res to out in a compiler diagnostic format. @@ -262,7 +350,7 @@ func (res *Result) WriteTo(out io.Writer) { // buildSSA constructs the SSA representation of Go-source function bodies. // Not needed in simpler modes, e.g. freevars. // -func buildSSA(o *oracle) { +func buildSSA(o *Oracle) { start := time.Now() o.prog.BuildAll() o.timers["SSA-build"] = time.Since(start) @@ -271,7 +359,7 @@ func buildSSA(o *oracle) { // ptrAnalysis runs the pointer analysis and returns the synthetic // root of the callgraph. // -func ptrAnalysis(o *oracle) pointer.CallGraphNode { +func ptrAnalysis(o *Oracle) pointer.CallGraphNode { start := time.Now() root := pointer.Analyze(&o.config) o.timers["pointer analysis"] = time.Since(start) @@ -399,15 +487,14 @@ func deref(typ types.Type) types.Type { // - an ast.Node, denoting an interval // - anything with a Pos() method: // ssa.Member, ssa.Value, ssa.Instruction, types.Object, pointer.Label, etc. -// - a bool, meaning the extent [o.startPos, o.endPos) of the user's query. -// (the value is ignored) +// - a QueryPos, denoting the extent of the user's query. // - nil, meaning no position at all. // // The output format is is compatible with the 'gnu' // compilation-error-regexp in Emacs' compilation mode. // TODO(adonovan): support other editors. // -func (o *oracle) fprintf(w io.Writer, pos interface{}, format string, args ...interface{}) { +func (o *Oracle) fprintf(w io.Writer, pos interface{}, format string, args ...interface{}) { var start, end token.Pos switch pos := pos.(type) { case ast.Node: @@ -421,9 +508,9 @@ func (o *oracle) fprintf(w io.Writer, pos interface{}, format string, args ...in }: start = pos.Pos() end = start - case bool: - start = o.startPos - end = o.endPos + case *QueryPos: + start = pos.start + end = pos.end case nil: // no-op default: @@ -448,14 +535,14 @@ func (o *oracle) fprintf(w io.Writer, pos interface{}, format string, args ...in } // errorf is like fprintf, but returns a formatted error string. -func (o *oracle) errorf(pos interface{}, format string, args ...interface{}) error { +func (o *Oracle) errorf(pos interface{}, format string, args ...interface{}) error { var buf bytes.Buffer o.fprintf(&buf, pos, format, args...) return errors.New(buf.String()) } // printNode returns the pretty-printed syntax of n. -func (o *oracle) printNode(n ast.Node) string { +func (o *Oracle) printNode(n ast.Node) string { var buf bytes.Buffer printer.Fprint(&buf, o.prog.Fset, n) return buf.String() diff --git a/oracle/oracle_test.go b/oracle/oracle_test.go index 12b9c687..097ecf4c 100644 --- a/oracle/oracle_test.go +++ b/oracle/oracle_test.go @@ -44,6 +44,7 @@ import ( "strings" "testing" + "code.google.com/p/go.tools/importer" "code.google.com/p/go.tools/oracle" ) @@ -249,3 +250,53 @@ func TestOracle(t *testing.T) { } } } + +func TestMultipleQueries(t *testing.T) { + // Importer + var buildContext = build.Default + buildContext.GOPATH = "testdata" + imp := importer.New(&importer.Config{Build: &buildContext}) + + // Oracle + filename := "testdata/src/main/multi.go" + o, err := oracle.New(imp, []string{filename}, nil) + if err != nil { + t.Fatalf("oracle.New failed: %s", err) + } + + // QueryPos + pos := filename + ":#54,#58" + qpos, err := oracle.ParseQueryPos(imp, pos, true) + if err != nil { + t.Fatalf("oracle.ParseQueryPos(%q) failed: %s", pos, err) + } + // SSA is built and we have the QueryPos. + // Release the other ASTs and type info to the GC. + imp = nil + + // Run different query moes on same scope and selection. + out := new(bytes.Buffer) + for _, mode := range [...]string{"callers", "describe", "freevars"} { + res, err := o.Query(mode, qpos) + if err != nil { + t.Errorf("(*oracle.Oracle).Query(%q) failed: %s", pos, err) + } + capture := new(bytes.Buffer) // capture standard output + res.WriteTo(capture) + for _, line := range strings.Split(capture.String(), "\n") { + fmt.Fprintf(out, "%s\n", stripLocation(line)) + } + } + want := `multi.f is called from these 1 sites: + static function call from multi.main + +function call (or conversion) of type () + +Free identifiers: +var x int + +` + if got := out.String(); got != want { + t.Errorf("Query output differs; want <<%s>>, got <<%s>>\n", want, got) + } +} diff --git a/oracle/peers.go b/oracle/peers.go index 23dd1bef..09dba85d 100644 --- a/oracle/peers.go +++ b/oracle/peers.go @@ -22,10 +22,10 @@ import ( // TODO(adonovan): permit the user to query based on a MakeChan (not send/recv), // or the implicit receive in "for v := range ch". // -func peers(o *oracle) (queryResult, error) { - arrowPos := findArrow(o) +func peers(o *Oracle, qpos *QueryPos) (queryResult, error) { + arrowPos := findArrow(qpos) if arrowPos == token.NoPos { - return nil, o.errorf(o.queryPath[0], "there is no send/receive here") + return nil, o.errorf(qpos.path[0], "there is no send/receive here") } buildSSA(o) @@ -111,8 +111,8 @@ func peers(o *oracle) (queryResult, error) { // findArrow returns the position of the enclosing send/receive op // (<-) for the query position, or token.NoPos if not found. // -func findArrow(o *oracle) token.Pos { - for _, n := range o.queryPath { +func findArrow(qpos *QueryPos) token.Pos { + for _, n := range qpos.path { switch n := n.(type) { case *ast.UnaryExpr: if n.Op == token.ARROW { diff --git a/oracle/referrers.go b/oracle/referrers.go index ab369d07..4bf16bd8 100644 --- a/oracle/referrers.go +++ b/oracle/referrers.go @@ -16,16 +16,16 @@ import ( // Referrers reports all identifiers that resolve to the same object // as the queried identifier, within any package in the analysis scope. // -func referrers(o *oracle) (queryResult, error) { - id, _ := o.queryPath[0].(*ast.Ident) +func referrers(o *Oracle, qpos *QueryPos) (queryResult, error) { + id, _ := qpos.path[0].(*ast.Ident) if id == nil { - return nil, o.errorf(false, "no identifier here") + return nil, o.errorf(qpos, "no identifier here") } - obj := o.queryPkgInfo.ObjectOf(id) + obj := qpos.info.ObjectOf(id) if obj == nil { // Happens for y in "switch y := x.(type)", but I think that's all. - return nil, o.errorf(false, "no object for identifier") + return nil, o.errorf(qpos, "no object for identifier") } // Iterate over all go/types' resolver facts for the entire program. diff --git a/oracle/testdata/src/main/calls.golden b/oracle/testdata/src/main/calls.golden index e0312884..225570a7 100644 --- a/oracle/testdata/src/main/calls.golden +++ b/oracle/testdata/src/main/calls.golden @@ -80,7 +80,6 @@ Error: this is a type conversion, not a function call -------- @callees callees-err-bad-selection -------- Error: ambiguous selection within function call (or conversion) - -------- @callees callees-err-deadcode1 -------- Error: this call site is unreachable in this analysis diff --git a/oracle/testdata/src/main/freevars.golden b/oracle/testdata/src/main/freevars.golden index 4b29fa45..a4c2c77c 100644 --- a/oracle/testdata/src/main/freevars.golden +++ b/oracle/testdata/src/main/freevars.golden @@ -1,11 +1,11 @@ -------- @freevars fv1 -------- -Free identifers: +Free identifiers: type C main.C const exp int var x int -------- @freevars fv2 -------- -Free identifers: +Free identifiers: var s.t.a int var s.t.b int var s.x int @@ -13,6 +13,6 @@ var x int var y int32 -------- @freevars fv3 -------- -Free identifers: +Free identifiers: var x int diff --git a/oracle/testdata/src/main/multi.go b/oracle/testdata/src/main/multi.go new file mode 100644 index 00000000..54caf15d --- /dev/null +++ b/oracle/testdata/src/main/multi.go @@ -0,0 +1,13 @@ +package multi + +func g(x int) { +} + +func f() { + x := 1 + g(x) // "g(x)" is the selection for multiple queries +} + +func main() { + f() +}