go.tools/oracle: refactor Oracle API to allow repeated queries on same scope.

The existing standalone Query function builds an importer, ssa.Program, oracle,
and query position, executes the query and returns the result.
For clients (such as Frederik Zipp's web-based github.com/fzipp/pythia tool)
that wish to load the program once and make several queries, we now expose
these as separate operations too.  Here's a client, in pseudocode:

        o := oracle.New(...)
        for ... {
                qpos := o.ParseQueryPos(...)
                res := o.Query(mode, qpos)
                print result
        }

NB: this is a slight deoptimisation in the one-shot case since we have to
build the entire SSA program with debug info, not just the query package,
since we now don't know the query package at that time.

The 'exact' param to ParseQueryPos needs more thought since its
ideal value is a function of the query mode.  This will do for now.

Details:
- expose Oracle type, New() func and Query() method.
- expose QueryPos type and ParseQueryPos func.
- improved package doc comment.
- un-exposed the "needs" bits.
- added test.

R=crawshaw
CC=frederik.zipp, golang-dev
https://golang.org/cl/13810043
This commit is contained in:
Alan Donovan 2013-09-23 15:02:18 -04:00
parent eb130cb481
commit 25a0cc4bfd
14 changed files with 332 additions and 173 deletions

View File

@ -20,28 +20,28 @@ import (
// //
// TODO(adonovan): if a callee is a wrapper, show the callee's callee. // TODO(adonovan): if a callee is a wrapper, show the callee's callee.
// //
func callees(o *oracle) (queryResult, error) { func callees(o *Oracle, qpos *QueryPos) (queryResult, error) {
// Determine the enclosing call for the specified position. // Determine the enclosing call for the specified position.
var call *ast.CallExpr var call *ast.CallExpr
for _, n := range o.queryPath { for _, n := range qpos.path {
if call, _ = n.(*ast.CallExpr); call != nil { if call, _ = n.(*ast.CallExpr); call != nil {
break break
} }
} }
if call == nil { if call == nil {
return nil, o.errorf(o.queryPath[0], "there is no function call here") return nil, o.errorf(qpos.path[0], "there is no function call here")
} }
// TODO(adonovan): issue an error if the call is "too far // TODO(adonovan): issue an error if the call is "too far
// away" from the current selection, as this most likely is // away" from the current selection, as this most likely is
// not what the user intended. // not what the user intended.
// Reject type conversions. // Reject type conversions.
if o.queryPkgInfo.IsType(call.Fun) { if qpos.info.IsType(call.Fun) {
return nil, o.errorf(call, "this is a type conversion, not a function call") return nil, o.errorf(call, "this is a type conversion, not a function call")
} }
// Reject calls to built-ins. // Reject calls to built-ins.
if b, ok := o.queryPkgInfo.TypeOf(call.Fun).(*types.Builtin); ok { if b, ok := qpos.info.TypeOf(call.Fun).(*types.Builtin); ok {
return nil, o.errorf(call, "this is a call to the built-in '%s' operator", b.Name()) return nil, o.errorf(call, "this is a call to the built-in '%s' operator", b.Name())
} }

View File

@ -17,20 +17,20 @@ import (
// //
// TODO(adonovan): if a caller is a wrapper, show the caller's caller. // TODO(adonovan): if a caller is a wrapper, show the caller's caller.
// //
func callers(o *oracle) (queryResult, error) { func callers(o *Oracle, qpos *QueryPos) (queryResult, error) {
pkg := o.prog.Package(o.queryPkgInfo.Pkg) pkg := o.prog.Package(qpos.info.Pkg)
if pkg == nil { if pkg == nil {
return nil, o.errorf(o.queryPath[0], "no SSA package") return nil, o.errorf(qpos.path[0], "no SSA package")
} }
if !ssa.HasEnclosingFunction(pkg, o.queryPath) { if !ssa.HasEnclosingFunction(pkg, qpos.path) {
return nil, o.errorf(o.queryPath[0], "this position is not inside a function") return nil, o.errorf(qpos.path[0], "this position is not inside a function")
} }
buildSSA(o) buildSSA(o)
target := ssa.EnclosingFunction(pkg, o.queryPath) target := ssa.EnclosingFunction(pkg, qpos.path)
if target == nil { if target == nil {
return nil, o.errorf(o.queryPath[0], "no SSA function built for this location (dead code?)") return nil, o.errorf(qpos.path[0], "no SSA function built for this location (dead code?)")
} }
// Run the pointer analysis, recording each // Run the pointer analysis, recording each

View File

@ -25,7 +25,7 @@ import (
// //
// TODO(adonovan): elide nodes for synthetic functions? // TODO(adonovan): elide nodes for synthetic functions?
// //
func callgraph(o *oracle) (queryResult, error) { func callgraph(o *Oracle, _ *QueryPos) (queryResult, error) {
buildSSA(o) buildSSA(o)
// Run the pointer analysis and build the complete callgraph. // Run the pointer analysis and build the complete callgraph.
@ -35,6 +35,7 @@ func callgraph(o *oracle) (queryResult, error) {
// Assign (preorder) numbers to all the callgraph nodes. // Assign (preorder) numbers to all the callgraph nodes.
// TODO(adonovan): the callgraph API should do this for us. // TODO(adonovan): the callgraph API should do this for us.
// (Actually, it does have unique numbers under the hood.)
numbering := make(map[pointer.CallGraphNode]int) numbering := make(map[pointer.CallGraphNode]int)
var number func(cgn pointer.CallGraphNode) var number func(cgn pointer.CallGraphNode)
number = func(cgn pointer.CallGraphNode) { number = func(cgn pointer.CallGraphNode) {
@ -70,6 +71,8 @@ Some nodes may appear multiple times due to context-sensitive
treatment of some calls. treatment of some calls.
`) `)
// TODO(adonovan): compute the numbers as we print; right now
// it depends on map iteration so it's arbitrary,which is ugly.
seen := make(map[pointer.CallGraphNode]bool) seen := make(map[pointer.CallGraphNode]bool)
var print func(cgn pointer.CallGraphNode, indent int) var print func(cgn pointer.CallGraphNode, indent int)
print = func(cgn pointer.CallGraphNode, indent int) { print = func(cgn pointer.CallGraphNode, indent int) {

View File

@ -22,21 +22,21 @@ import (
// TODO(adonovan): permit user to specify a starting point other than // TODO(adonovan): permit user to specify a starting point other than
// the analysis root. // the analysis root.
// //
func callstack(o *oracle) (queryResult, error) { func callstack(o *Oracle, qpos *QueryPos) (queryResult, error) {
pkg := o.prog.Package(o.queryPkgInfo.Pkg) pkg := o.prog.Package(qpos.info.Pkg)
if pkg == nil { if pkg == nil {
return nil, o.errorf(o.queryPath[0], "no SSA package") return nil, o.errorf(qpos.path[0], "no SSA package")
} }
if !ssa.HasEnclosingFunction(pkg, o.queryPath) { if !ssa.HasEnclosingFunction(pkg, qpos.path) {
return nil, o.errorf(o.queryPath[0], "this position is not inside a function") return nil, o.errorf(qpos.path[0], "this position is not inside a function")
} }
buildSSA(o) buildSSA(o)
target := ssa.EnclosingFunction(pkg, o.queryPath) target := ssa.EnclosingFunction(pkg, qpos.path)
if target == nil { if target == nil {
return nil, o.errorf(o.queryPath[0], return nil, o.errorf(qpos.path[0],
"no SSA function built for this location (dead code?)") "no SSA function built for this location (dead code?)")
} }
@ -74,19 +74,21 @@ func callstack(o *oracle) (queryResult, error) {
} }
return &callstackResult{ return &callstackResult{
qpos: qpos,
target: target, target: target,
callstack: callstack, callstack: callstack,
}, nil }, nil
} }
type callstackResult struct { type callstackResult struct {
qpos *QueryPos
target *ssa.Function target *ssa.Function
callstack []pointer.CallSite callstack []pointer.CallSite
} }
func (r *callstackResult) display(printf printfFunc) { func (r *callstackResult) display(printf printfFunc) {
if r.callstack != nil { if r.callstack != nil {
printf(false, "Found a call path from root to %s", r.target) printf(r.qpos, "Found a call path from root to %s", r.target)
printf(r.target, "%s", r.target) printf(r.target, "%s", r.target)
for _, site := range r.callstack { for _, site := range r.callstack {
printf(site, "%s from %s", site.Description(), site.Caller().Func()) printf(site, "%s from %s", site.Description(), site.Caller().Func())

View File

@ -33,25 +33,25 @@ import (
// //
// All printed sets are sorted to ensure determinism. // All printed sets are sorted to ensure determinism.
// //
func describe(o *oracle) (queryResult, error) { func describe(o *Oracle, qpos *QueryPos) (queryResult, error) {
if false { // debugging if false { // debugging
o.fprintf(os.Stderr, o.queryPath[0], "you selected: %s %s", o.fprintf(os.Stderr, qpos.path[0], "you selected: %s %s",
importer.NodeDescription(o.queryPath[0]), pathToString2(o.queryPath)) importer.NodeDescription(qpos.path[0]), pathToString2(qpos.path))
} }
path, action := findInterestingNode(o.queryPkgInfo, o.queryPath) path, action := findInterestingNode(qpos.info, qpos.path)
switch action { switch action {
case actionExpr: case actionExpr:
return describeValue(o, path) return describeValue(o, qpos, path)
case actionType: case actionType:
return describeType(o, path) return describeType(o, qpos, path)
case actionPackage: case actionPackage:
return describePackage(o, path) return describePackage(o, qpos, path)
case actionStmt: case actionStmt:
return describeStmt(o, path) return describeStmt(o, qpos, path)
case actionUnknown: case actionUnknown:
return &describeUnknownResult{path[0]}, nil return &describeUnknownResult{path[0]}, nil
@ -305,11 +305,11 @@ func findInterestingNode(pkginfo *importer.PackageInfo, path []ast.Node) ([]ast.
// to the root of the AST is path. It may return a nil Value without // to the root of the AST is path. It may return a nil Value without
// an error to indicate the pointer analysis is not appropriate. // an error to indicate the pointer analysis is not appropriate.
// //
func ssaValueForIdent(o *oracle, obj types.Object, path []ast.Node) (ssa.Value, error) { func ssaValueForIdent(prog *ssa.Program, qinfo *importer.PackageInfo, obj types.Object, path []ast.Node) (ssa.Value, error) {
if obj, ok := obj.(*types.Var); ok { if obj, ok := obj.(*types.Var); ok {
pkg := o.prog.Package(o.queryPkgInfo.Pkg) pkg := prog.Package(qinfo.Pkg)
pkg.Build() pkg.Build()
if v := o.prog.VarValue(obj, pkg, path); v != nil { if v := prog.VarValue(obj, pkg, path); v != nil {
// Don't run pointer analysis on a ref to a const expression. // Don't run pointer analysis on a ref to a const expression.
if _, ok := v.(*ssa.Const); ok { if _, ok := v.(*ssa.Const); ok {
v = nil v = nil
@ -328,8 +328,8 @@ func ssaValueForIdent(o *oracle, obj types.Object, path []ast.Node) (ssa.Value,
// return a nil Value without an error to indicate the pointer // return a nil Value without an error to indicate the pointer
// analysis is not appropriate. // analysis is not appropriate.
// //
func ssaValueForExpr(o *oracle, path []ast.Node) (ssa.Value, error) { func ssaValueForExpr(prog *ssa.Program, qinfo *importer.PackageInfo, path []ast.Node) (ssa.Value, error) {
pkg := o.prog.Package(o.queryPkgInfo.Pkg) pkg := prog.Package(qinfo.Pkg)
pkg.SetDebugMode(true) pkg.SetDebugMode(true)
pkg.Build() pkg.Build()
@ -345,7 +345,7 @@ func ssaValueForExpr(o *oracle, path []ast.Node) (ssa.Value, error) {
return nil, fmt.Errorf("can't locate SSA Value for expression in %s", fn) return nil, fmt.Errorf("can't locate SSA Value for expression in %s", fn)
} }
func describeValue(o *oracle, path []ast.Node) (*describeValueResult, error) { func describeValue(o *Oracle, qpos *QueryPos, path []ast.Node) (*describeValueResult, error) {
var expr ast.Expr var expr ast.Expr
var obj types.Object var obj types.Object
switch n := path[0].(type) { switch n := path[0].(type) {
@ -353,7 +353,7 @@ func describeValue(o *oracle, path []ast.Node) (*describeValueResult, error) {
// ambiguous ValueSpec containing multiple names // ambiguous ValueSpec containing multiple names
return nil, o.errorf(n, "multiple value specification") return nil, o.errorf(n, "multiple value specification")
case *ast.Ident: case *ast.Ident:
obj = o.queryPkgInfo.ObjectOf(n) obj = qpos.info.ObjectOf(n)
expr = n expr = n
case ast.Expr: case ast.Expr:
expr = n expr = n
@ -362,8 +362,8 @@ func describeValue(o *oracle, path []ast.Node) (*describeValueResult, error) {
return nil, o.errorf(n, "unexpected AST for expr: %T", n) return nil, o.errorf(n, "unexpected AST for expr: %T", n)
} }
typ := o.queryPkgInfo.TypeOf(expr) typ := qpos.info.TypeOf(expr)
constVal := o.queryPkgInfo.ValueOf(expr) constVal := qpos.info.ValueOf(expr)
// From this point on, we cannot fail with an error. // From this point on, we cannot fail with an error.
// Failure to run the pointer analysis will be reported later. // Failure to run the pointer analysis will be reported later.
@ -386,11 +386,11 @@ func describeValue(o *oracle, path []ast.Node) (*describeValueResult, error) {
var value ssa.Value var value ssa.Value
if obj != nil { if obj != nil {
// def/ref of func/var/const object // def/ref of func/var/const object
value, ptaErr = ssaValueForIdent(o, obj, path) value, ptaErr = ssaValueForIdent(o.prog, qpos.info, obj, path)
} else { } else {
// any other expression // any other expression
if o.queryPkgInfo.ValueOf(path[0].(ast.Expr)) == nil { // non-constant? if qpos.info.ValueOf(path[0].(ast.Expr)) == nil { // non-constant?
value, ptaErr = ssaValueForExpr(o, path) value, ptaErr = ssaValueForExpr(o.prog, qpos.info, path)
} }
} }
if value != nil { if value != nil {
@ -404,6 +404,7 @@ func describeValue(o *oracle, path []ast.Node) (*describeValueResult, error) {
} }
return &describeValueResult{ return &describeValueResult{
qpos: qpos,
expr: expr, expr: expr,
typ: typ, typ: typ,
constVal: constVal, constVal: constVal,
@ -414,7 +415,7 @@ func describeValue(o *oracle, path []ast.Node) (*describeValueResult, error) {
} }
// describePointer runs the pointer analysis of the selected SSA value. // describePointer runs the pointer analysis of the selected SSA value.
func describePointer(o *oracle, v ssa.Value, indirect bool) (ptrs []pointerResult, err error) { func describePointer(o *Oracle, v ssa.Value, indirect bool) (ptrs []pointerResult, err error) {
buildSSA(o) buildSSA(o)
// TODO(adonovan): don't run indirect pointer analysis on non-ptr-ptrlike types. // TODO(adonovan): don't run indirect pointer analysis on non-ptr-ptrlike types.
@ -454,6 +455,7 @@ type pointerResult struct {
} }
type describeValueResult struct { type describeValueResult struct {
qpos *QueryPos
expr ast.Expr // query node expr ast.Expr // query node
typ types.Type // type of expression typ types.Type // type of expression
constVal exact.Value // value of expression, if constant constVal exact.Value // value of expression, if constant
@ -524,7 +526,7 @@ func (r *describeValueResult) display(printf printfFunc) {
// reflect.Value expression. // reflect.Value expression.
if len(r.ptrs) > 0 { if len(r.ptrs) > 0 {
printf(false, "this %s may contain these dynamic types:", r.typ) printf(r.qpos, "this %s may contain these dynamic types:", r.typ)
for _, ptr := range r.ptrs { for _, ptr := range r.ptrs {
var obj types.Object var obj types.Object
if nt, ok := deref(ptr.typ).(*types.Named); ok { if nt, ok := deref(ptr.typ).(*types.Named); ok {
@ -538,15 +540,15 @@ func (r *describeValueResult) display(printf printfFunc) {
} }
} }
} else { } else {
printf(false, "this %s cannot contain any dynamic types.", r.typ) printf(r.qpos, "this %s cannot contain any dynamic types.", r.typ)
} }
} else { } else {
// Show labels for other expressions. // Show labels for other expressions.
if ptr := r.ptrs[0]; len(ptr.labels) > 0 { if ptr := r.ptrs[0]; len(ptr.labels) > 0 {
printf(false, "value may point to these labels:") printf(r.qpos, "value may point to these labels:")
printLabels(printf, ptr.labels, "\t") printLabels(printf, ptr.labels, "\t")
} else { } else {
printf(false, "value cannot point to anything.") printf(r.qpos, "value cannot point to anything.")
} }
} }
} }
@ -622,12 +624,12 @@ func printLabels(printf printfFunc, labels []*pointer.Label, prefix string) {
// ---- TYPE ------------------------------------------------------------ // ---- TYPE ------------------------------------------------------------
func describeType(o *oracle, path []ast.Node) (*describeTypeResult, error) { func describeType(o *Oracle, qpos *QueryPos, path []ast.Node) (*describeTypeResult, error) {
var description string var description string
var t types.Type var t types.Type
switch n := path[0].(type) { switch n := path[0].(type) {
case *ast.Ident: case *ast.Ident:
t = o.queryPkgInfo.TypeOf(n) t = qpos.info.TypeOf(n)
switch t := t.(type) { switch t := t.(type) {
case *types.Basic: case *types.Basic:
description = "reference to built-in type " + t.String() description = "reference to built-in type " + t.String()
@ -642,7 +644,7 @@ func describeType(o *oracle, path []ast.Node) (*describeTypeResult, error) {
} }
case ast.Expr: case ast.Expr:
t = o.queryPkgInfo.TypeOf(n) t = qpos.info.TypeOf(n)
description = "type " + t.String() description = "type " + t.String()
default: default:
@ -654,7 +656,7 @@ func describeType(o *oracle, path []ast.Node) (*describeTypeResult, error) {
node: path[0], node: path[0],
description: description, description: description,
typ: t, typ: t,
methods: accessibleMethods(t, o.queryPkgInfo.Pkg), methods: accessibleMethods(t, qpos.info.Pkg),
}, nil }, nil
} }
@ -708,7 +710,7 @@ func (r *describeTypeResult) toJSON(res *json.Result, fset *token.FileSet) {
// ---- PACKAGE ------------------------------------------------------------ // ---- PACKAGE ------------------------------------------------------------
func describePackage(o *oracle, path []ast.Node) (*describePackageResult, error) { func describePackage(o *Oracle, qpos *QueryPos, path []ast.Node) (*describePackageResult, error) {
var description string var description string
var pkg *types.Package var pkg *types.Package
switch n := path[0].(type) { switch n := path[0].(type) {
@ -724,12 +726,12 @@ func describePackage(o *oracle, path []ast.Node) (*describePackageResult, error)
case *ast.Ident: case *ast.Ident:
if _, isDef := path[1].(*ast.File); isDef { if _, isDef := path[1].(*ast.File); isDef {
// e.g. package id // e.g. package id
pkg = o.queryPkgInfo.Pkg pkg = qpos.info.Pkg
description = fmt.Sprintf("definition of package %q", pkg.Path()) description = fmt.Sprintf("definition of package %q", pkg.Path())
} else { } else {
// e.g. import id // e.g. import id
// or id.F() // or id.F()
pkg = o.queryPkgInfo.ObjectOf(n).Pkg() pkg = qpos.info.ObjectOf(n).Pkg()
description = fmt.Sprintf("reference to package %q", pkg.Path()) description = fmt.Sprintf("reference to package %q", pkg.Path())
} }
@ -744,11 +746,11 @@ func describePackage(o *oracle, path []ast.Node) (*describePackageResult, error)
// Enumerate the accessible package members // Enumerate the accessible package members
// in lexicographic order. // in lexicographic order.
for _, name := range pkg.Scope().Names() { for _, name := range pkg.Scope().Names() {
if pkg == o.queryPkgInfo.Pkg || ast.IsExported(name) { if pkg == qpos.info.Pkg || ast.IsExported(name) {
mem := pkg.Scope().Lookup(name) mem := pkg.Scope().Lookup(name)
var methods []*types.Selection var methods []*types.Selection
if mem, ok := mem.(*types.TypeName); ok { if mem, ok := mem.(*types.TypeName); ok {
methods = accessibleMethods(mem.Type(), o.queryPkgInfo.Pkg) methods = accessibleMethods(mem.Type(), qpos.info.Pkg)
} }
members = append(members, &describeMember{ members = append(members, &describeMember{
mem, mem,
@ -878,11 +880,11 @@ func tokenOf(o types.Object) string {
// ---- STATEMENT ------------------------------------------------------------ // ---- STATEMENT ------------------------------------------------------------
func describeStmt(o *oracle, path []ast.Node) (*describeStmtResult, error) { func describeStmt(o *Oracle, qpos *QueryPos, path []ast.Node) (*describeStmtResult, error) {
var description string var description string
switch n := path[0].(type) { switch n := path[0].(type) {
case *ast.Ident: case *ast.Ident:
if o.queryPkgInfo.ObjectOf(n).Pos() == n.Pos() { if qpos.info.ObjectOf(n).Pos() == n.Pos() {
description = "labelled statement" description = "labelled statement"
} else { } else {
description = "reference to labelled statement" description = "reference to labelled statement"

View File

@ -26,9 +26,9 @@ import (
// these might be interesting. Perhaps group the results into three // these might be interesting. Perhaps group the results into three
// bands. // bands.
// //
func freevars(o *oracle) (queryResult, error) { func freevars(o *Oracle, qpos *QueryPos) (queryResult, error) {
file := o.queryPath[len(o.queryPath)-1] // the enclosing file file := qpos.path[len(qpos.path)-1] // the enclosing file
fileScope := o.queryPkgInfo.Scopes[file] fileScope := qpos.info.Scopes[file]
pkgScope := fileScope.Parent() pkgScope := fileScope.Parent()
// The id and sel functions return non-nil if they denote an // The id and sel functions return non-nil if they denote an
@ -49,7 +49,7 @@ func freevars(o *oracle) (queryResult, error) {
} }
id = func(n *ast.Ident) types.Object { id = func(n *ast.Ident) types.Object {
obj := o.queryPkgInfo.ObjectOf(n) obj := qpos.info.ObjectOf(n)
if obj == nil { if obj == nil {
return nil // TODO(adonovan): fix: this fails for *types.Label. return nil // TODO(adonovan): fix: this fails for *types.Label.
panic(o.errorf(n, "no types.Object for ast.Ident")) panic(o.errorf(n, "no types.Object for ast.Ident"))
@ -70,7 +70,7 @@ func freevars(o *oracle) (queryResult, error) {
if scope == fileScope || scope == pkgScope { if scope == fileScope || scope == pkgScope {
return nil // defined at file or package scope return nil // defined at file or package scope
} }
if o.startPos <= obj.Pos() && obj.Pos() <= o.endPos { if qpos.start <= obj.Pos() && obj.Pos() <= qpos.end {
return nil // defined within selection => not free return nil // defined within selection => not free
} }
return obj return obj
@ -82,7 +82,7 @@ func freevars(o *oracle) (queryResult, error) {
refsMap := make(map[string]freevarsRef) refsMap := make(map[string]freevarsRef)
// Visit all the identifiers in the selected ASTs. // Visit all the identifiers in the selected ASTs.
ast.Inspect(o.queryPath[0], func(n ast.Node) bool { ast.Inspect(qpos.path[0], func(n ast.Node) bool {
if n == nil { if n == nil {
return true // popping DFS stack return true // popping DFS stack
} }
@ -90,7 +90,7 @@ func freevars(o *oracle) (queryResult, error) {
// Is this node contained within the selection? // Is this node contained within the selection?
// (freevars permits inexact selections, // (freevars permits inexact selections,
// like two stmts in a block.) // like two stmts in a block.)
if o.startPos <= n.Pos() && n.End() <= o.endPos { if qpos.start <= n.Pos() && n.End() <= qpos.end {
var obj types.Object var obj types.Object
var prune bool var prune bool
switch n := n.(type) { switch n := n.(type) {
@ -119,7 +119,7 @@ func freevars(o *oracle) (queryResult, error) {
panic(obj) panic(obj)
} }
typ := o.queryPkgInfo.TypeOf(n.(ast.Expr)) typ := qpos.info.TypeOf(n.(ast.Expr))
ref := freevarsRef{kind, o.printNode(n), typ, obj} ref := freevarsRef{kind, o.printNode(n), typ, obj}
refsMap[ref.ref] = ref refsMap[ref.ref] = ref
@ -139,12 +139,14 @@ func freevars(o *oracle) (queryResult, error) {
sort.Sort(byRef(refs)) sort.Sort(byRef(refs))
return &freevarsResult{ return &freevarsResult{
qpos: qpos,
fset: o.prog.Fset, fset: o.prog.Fset,
refs: refs, refs: refs,
}, nil }, nil
} }
type freevarsResult struct { type freevarsResult struct {
qpos *QueryPos
fset *token.FileSet fset *token.FileSet
refs []freevarsRef refs []freevarsRef
} }
@ -158,9 +160,9 @@ type freevarsRef struct {
func (r *freevarsResult) display(printf printfFunc) { func (r *freevarsResult) display(printf printfFunc) {
if len(r.refs) == 0 { if len(r.refs) == 0 {
printf(false, "No free identifers.") printf(r.qpos, "No free identifiers.")
} else { } else {
printf(false, "Free identifers:") printf(r.qpos, "Free identifiers:")
for _, ref := range r.refs { for _, ref := range r.refs {
printf(ref.obj, "%s %s %s", ref.kind, ref.ref, ref.typ) printf(ref.obj, "%s %s %s", ref.kind, ref.ref, ref.typ)
} }

View File

@ -30,8 +30,8 @@ import (
// actually occur, with examples? (NB: this is not a conservative // actually occur, with examples? (NB: this is not a conservative
// answer due to ChangeInterface, i.e. subtyping among interfaces.) // answer due to ChangeInterface, i.e. subtyping among interfaces.)
// //
func implements(o *oracle) (queryResult, error) { func implements(o *Oracle, qpos *QueryPos) (queryResult, error) {
pkg := o.queryPkgInfo.Pkg pkg := qpos.info.Pkg
// Compute set of named interface/concrete types at package level. // Compute set of named interface/concrete types at package level.
var interfaces, concretes []*types.Named var interfaces, concretes []*types.Named

View File

@ -2,6 +2,12 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
// Package oracle contains the implementation of the oracle tool whose
// command-line is provided by code.google.com/p/go.tools/cmd/oracle.
//
// http://golang.org/s/oracle-design
// http://golang.org/s/oracle-user-manual
//
package oracle package oracle
// This file defines oracle.Query, the entry point for the oracle tool. // This file defines oracle.Query, the entry point for the oracle tool.
@ -33,16 +39,12 @@ import (
"code.google.com/p/go.tools/ssa" "code.google.com/p/go.tools/ssa"
) )
type oracle struct { // An Oracle holds the program state required for one or more queries.
type Oracle struct {
out io.Writer // standard output out io.Writer // standard output
prog *ssa.Program // the SSA program [only populated if need&SSA] prog *ssa.Program // the SSA program [only populated if need&SSA]
config pointer.Config // pointer analysis configuration config pointer.Config // pointer analysis configuration
// need&(Pos|ExactPos):
startPos, endPos token.Pos // source extent of query
queryPkgInfo *importer.PackageInfo // type info for the queried package
queryPath []ast.Node // AST path from query node to root of ast.File
// need&AllTypeInfo // need&AllTypeInfo
typeInfo map[*types.Package]*importer.PackageInfo // type info for all ASTs in the program typeInfo map[*types.Package]*importer.PackageInfo // type info for all ASTs in the program
@ -53,30 +55,42 @@ type oracle struct {
// //
// Typed ASTs for the whole program are always constructed // Typed ASTs for the whole program are always constructed
// transiently; they are retained only for the queried package unless // transiently; they are retained only for the queried package unless
// AllTypeInfo is set. // needAllTypeInfo is set.
const ( const (
Pos = 1 << iota // needs a position needPos = 1 << iota // needs a position
ExactPos // needs an exact AST selection; implies Pos needExactPos // needs an exact AST selection; implies needPos
AllTypeInfo // needs to retain type info for all ASTs in the program needAllTypeInfo // needs to retain type info for all ASTs in the program
SSA // needs ssa.Packages for whole program needSSA // needs ssa.Packages for whole program
PTA = SSA // needs pointer analysis needSSADebug // needs debug info for ssa.Packages
needPTA = needSSA // needs pointer analysis
needAll = -1 // needs everything (e.g. a sequence of queries)
) )
type modeInfo struct { type modeInfo struct {
name string
needs int needs int
impl func(*oracle) (queryResult, error) impl func(*Oracle, *QueryPos) (queryResult, error)
} }
var modes = map[string]modeInfo{ var modes = []*modeInfo{
"callees": modeInfo{PTA | ExactPos, callees}, {"callees", needPTA | needExactPos, callees},
"callers": modeInfo{PTA | Pos, callers}, {"callers", needPTA | needPos, callers},
"callgraph": modeInfo{PTA, callgraph}, {"callgraph", needPTA, callgraph},
"callstack": modeInfo{PTA | Pos, callstack}, {"callstack", needPTA | needPos, callstack},
"describe": modeInfo{PTA | ExactPos, describe}, {"describe", needPTA | needSSADebug | needExactPos, describe},
"freevars": modeInfo{Pos, freevars}, {"freevars", needPos, freevars},
"implements": modeInfo{Pos, implements}, {"implements", needPos, implements},
"peers": modeInfo{PTA | Pos, peers}, {"peers", needPTA | needSSADebug | needPos, peers},
"referrers": modeInfo{AllTypeInfo | Pos, referrers}, {"referrers", needAllTypeInfo | needPos, referrers},
}
func findMode(mode string) *modeInfo {
for _, m := range modes {
if m.name == mode {
return m
}
}
return nil
} }
type printfFunc func(pos interface{}, format string, args ...interface{}) type printfFunc func(pos interface{}, format string, args ...interface{})
@ -93,6 +107,17 @@ type warning struct {
args []interface{} args []interface{}
} }
// A QueryPos represents the position provided as input to a query:
// a textual extent in the program's source code, the AST node it
// corresponds to, and the package to which it belongs.
// Instances are created by ParseQueryPos.
//
type QueryPos struct {
start, end token.Pos // source extent of query
info *importer.PackageInfo // type info for the queried package
path []ast.Node // AST path from query node to root of ast.File
}
// A Result encapsulates the result of an oracle.Query. // A Result encapsulates the result of an oracle.Query.
// //
// Result instances implement the json.Marshaler interface, i.e. they // Result instances implement the json.Marshaler interface, i.e. they
@ -118,41 +143,91 @@ func (res *Result) MarshalJSON() ([]byte, error) {
return encjson.Marshal(resj) return encjson.Marshal(resj)
} }
// Query runs the oracle. // Query runs a single oracle query.
//
// args specify the main package in importer.CreatePackageFromArgs syntax. // args specify the main package in importer.CreatePackageFromArgs syntax.
// mode is the query mode ("callers", etc). // mode is the query mode ("callers", etc).
// pos is the selection in parseQueryPos() syntax.
// ptalog is the (optional) pointer-analysis log file. // ptalog is the (optional) pointer-analysis log file.
// buildContext is the optional configuration for locating packages. // buildContext is the go/build configuration for locating packages.
//
// Clients that intend to perform multiple queries against the same
// analysis scope should use this pattern instead:
//
// imp := importer.New(&importer.Config{Build: buildContext})
// o, err := oracle.New(imp, args, nil)
// if err != nil { ... }
// for ... {
// qpos, err := oracle.ParseQueryPos(imp, pos, needExact)
// if err != nil { ... }
//
// res, err := o.Query(mode, qpos)
// if err != nil { ... }
//
// // use res
// }
//
// TODO(adonovan): the ideal 'needsExact' parameter for ParseQueryPos
// depends on the query mode; how should we expose this?
// //
func Query(args []string, mode, pos string, ptalog io.Writer, buildContext *build.Context) (*Result, error) { func Query(args []string, mode, pos string, ptalog io.Writer, buildContext *build.Context) (*Result, error) {
minfo, ok := modes[mode] minfo := findMode(mode)
if !ok { if minfo == nil {
return nil, fmt.Errorf("invalid mode type: %q", mode) return nil, fmt.Errorf("invalid mode type: %q", mode)
} }
imp := importer.New(&importer.Config{Build: buildContext}) imp := importer.New(&importer.Config{Build: buildContext})
o := &oracle{ o, err := New(imp, args, ptalog)
if err != nil {
return nil, err
}
// Phase timing diagnostics.
// TODO(adonovan): needs more work.
// if false {
// defer func() {
// fmt.Println()
// for name, duration := range o.timers {
// fmt.Printf("# %-30s %s\n", name, duration)
// }
// }()
// }
var qpos *QueryPos
if minfo.needs&(needPos|needExactPos) != 0 {
var err error
qpos, err = ParseQueryPos(imp, pos, minfo.needs&needExactPos != 0)
if err != nil {
return nil, err
}
}
// SSA is built and we have the QueryPos.
// Release the other ASTs and type info to the GC.
imp = nil
return o.query(minfo, qpos)
}
// New constructs a new Oracle that can be used for a sequence of queries.
//
// imp will be used to load source code for imported packages.
// It must not yet have loaded any packages.
//
// args specify the main package in importer.CreatePackageFromArgs syntax.
//
// ptalog is the (optional) pointer-analysis log file.
//
func New(imp *importer.Importer, args []string, ptalog io.Writer) (*Oracle, error) {
return newOracle(imp, args, ptalog, needAll)
}
func newOracle(imp *importer.Importer, args []string, ptalog io.Writer, needs int) (*Oracle, error) {
o := &Oracle{
prog: ssa.NewProgram(imp.Fset, 0), prog: ssa.NewProgram(imp.Fset, 0),
timers: make(map[string]time.Duration), timers: make(map[string]time.Duration),
} }
o.config.Log = ptalog o.config.Log = ptalog
var res Result
o.config.Warn = func(pos token.Pos, format string, args ...interface{}) {
res.warnings = append(res.warnings, warning{pos, format, args})
}
// Phase timing diagnostics.
if false {
defer func() {
fmt.Println()
for name, duration := range o.timers {
fmt.Printf("# %-30s %s\n", name, duration)
}
}()
}
// Load/parse/type-check program from args. // Load/parse/type-check program from args.
start := time.Now() start := time.Now()
initialPkgInfos, args, err := imp.LoadInitialPackages(args) initialPkgInfos, args, err := imp.LoadInitialPackages(args)
@ -165,7 +240,7 @@ func Query(args []string, mode, pos string, ptalog io.Writer, buildContext *buil
o.timers["load/parse/type"] = time.Since(start) o.timers["load/parse/type"] = time.Since(start)
// Retain type info for all ASTs in the program. // Retain type info for all ASTs in the program.
if minfo.needs&AllTypeInfo != 0 { if needs&needAllTypeInfo != 0 {
m := make(map[*types.Package]*importer.PackageInfo) m := make(map[*types.Package]*importer.PackageInfo)
for _, p := range imp.AllPackages() { for _, p := range imp.AllPackages() {
m[p.Pkg] = p m[p.Pkg] = p
@ -173,32 +248,13 @@ func Query(args []string, mode, pos string, ptalog io.Writer, buildContext *buil
o.typeInfo = m o.typeInfo = m
} }
// Parse the source query position.
if minfo.needs&(Pos|ExactPos) != 0 {
var err error
o.startPos, o.endPos, err = parseQueryPos(o.prog.Fset, pos)
if err != nil {
return nil, err
}
var exact bool
o.queryPkgInfo, o.queryPath, exact = imp.PathEnclosingInterval(o.startPos, o.endPos)
if o.queryPath == nil {
return nil, o.errorf(false, "no syntax here")
}
if minfo.needs&ExactPos != 0 && !exact {
return nil, o.errorf(o.queryPath[0], "ambiguous selection within %s",
importer.NodeDescription(o.queryPath[0]))
}
}
// Create SSA package for the initial package and its dependencies. // Create SSA package for the initial package and its dependencies.
if minfo.needs&SSA != 0 { if needs&needSSA != 0 {
start = time.Now() start = time.Now()
// Create SSA packages. // Create SSA packages.
if err := o.prog.CreatePackages(imp); err != nil { if err := o.prog.CreatePackages(imp); err != nil {
return nil, o.errorf(false, "%s", err) return nil, o.errorf(nil, "%s", err)
} }
// Initial packages (specified on command line) // Initial packages (specified on command line)
@ -211,34 +267,66 @@ func Query(args []string, mode, pos string, ptalog io.Writer, buildContext *buil
// should build a single synthetic testmain package, // should build a single synthetic testmain package,
// not synthetic main functions to many packages. // not synthetic main functions to many packages.
if initialPkg.CreateTestMainFunction() == nil { if initialPkg.CreateTestMainFunction() == nil {
return nil, o.errorf(false, "analysis scope has no main() entry points") return nil, o.errorf(nil, "analysis scope has no main() entry points")
} }
} }
o.config.Mains = append(o.config.Mains, initialPkg) o.config.Mains = append(o.config.Mains, initialPkg)
} }
// Query package. if needs&needSSADebug != 0 {
if o.queryPkgInfo != nil { for _, pkg := range o.prog.AllPackages() {
pkg := o.prog.Package(o.queryPkgInfo.Pkg) pkg.SetDebugMode(true)
pkg.SetDebugMode(true) }
pkg.Build()
} }
o.timers["SSA-create"] = time.Since(start) o.timers["SSA-create"] = time.Since(start)
} }
// SSA is built and we have query{Path,PkgInfo}. return o, nil
// Release the other ASTs and type info to the GC. }
imp = nil
res.q, err = minfo.impl(o) // Query runs the query of the specified mode and selection.
func (o *Oracle) Query(mode string, qpos *QueryPos) (*Result, error) {
minfo := findMode(mode)
if minfo == nil {
return nil, fmt.Errorf("invalid mode type: %q", mode)
}
return o.query(minfo, qpos)
}
func (o *Oracle) query(minfo *modeInfo, qpos *QueryPos) (*Result, error) {
res := &Result{
mode: minfo.name,
fset: o.prog.Fset,
fprintf: o.fprintf, // captures o.prog, o.{start,end}Pos for later printing
}
o.config.Warn = func(pos token.Pos, format string, args ...interface{}) {
res.warnings = append(res.warnings, warning{pos, format, args})
}
var err error
res.q, err = minfo.impl(o, qpos)
if err != nil { if err != nil {
return nil, err return nil, err
} }
res.mode = mode return res, nil
res.fset = o.prog.Fset }
res.fprintf = o.fprintf // captures o.prog, o.{start,end}Pos for later printing
return &res, nil // ParseQueryPos parses the source query position pos.
// If needExact, it must identify a single AST subtree.
//
func ParseQueryPos(imp *importer.Importer, pos string, needExact bool) (*QueryPos, error) {
start, end, err := parseQueryPos(imp.Fset, pos)
if err != nil {
return nil, err
}
info, path, exact := imp.PathEnclosingInterval(start, end)
if path == nil {
return nil, errors.New("no syntax here")
}
if needExact && !exact {
return nil, fmt.Errorf("ambiguous selection within %s", importer.NodeDescription(path[0]))
}
return &QueryPos{start, end, info, path}, nil
} }
// WriteTo writes the oracle query result res to out in a compiler diagnostic format. // WriteTo writes the oracle query result res to out in a compiler diagnostic format.
@ -262,7 +350,7 @@ func (res *Result) WriteTo(out io.Writer) {
// buildSSA constructs the SSA representation of Go-source function bodies. // buildSSA constructs the SSA representation of Go-source function bodies.
// Not needed in simpler modes, e.g. freevars. // Not needed in simpler modes, e.g. freevars.
// //
func buildSSA(o *oracle) { func buildSSA(o *Oracle) {
start := time.Now() start := time.Now()
o.prog.BuildAll() o.prog.BuildAll()
o.timers["SSA-build"] = time.Since(start) o.timers["SSA-build"] = time.Since(start)
@ -271,7 +359,7 @@ func buildSSA(o *oracle) {
// ptrAnalysis runs the pointer analysis and returns the synthetic // ptrAnalysis runs the pointer analysis and returns the synthetic
// root of the callgraph. // root of the callgraph.
// //
func ptrAnalysis(o *oracle) pointer.CallGraphNode { func ptrAnalysis(o *Oracle) pointer.CallGraphNode {
start := time.Now() start := time.Now()
root := pointer.Analyze(&o.config) root := pointer.Analyze(&o.config)
o.timers["pointer analysis"] = time.Since(start) o.timers["pointer analysis"] = time.Since(start)
@ -399,15 +487,14 @@ func deref(typ types.Type) types.Type {
// - an ast.Node, denoting an interval // - an ast.Node, denoting an interval
// - anything with a Pos() method: // - anything with a Pos() method:
// ssa.Member, ssa.Value, ssa.Instruction, types.Object, pointer.Label, etc. // ssa.Member, ssa.Value, ssa.Instruction, types.Object, pointer.Label, etc.
// - a bool, meaning the extent [o.startPos, o.endPos) of the user's query. // - a QueryPos, denoting the extent of the user's query.
// (the value is ignored)
// - nil, meaning no position at all. // - nil, meaning no position at all.
// //
// The output format is is compatible with the 'gnu' // The output format is is compatible with the 'gnu'
// compilation-error-regexp in Emacs' compilation mode. // compilation-error-regexp in Emacs' compilation mode.
// TODO(adonovan): support other editors. // TODO(adonovan): support other editors.
// //
func (o *oracle) fprintf(w io.Writer, pos interface{}, format string, args ...interface{}) { func (o *Oracle) fprintf(w io.Writer, pos interface{}, format string, args ...interface{}) {
var start, end token.Pos var start, end token.Pos
switch pos := pos.(type) { switch pos := pos.(type) {
case ast.Node: case ast.Node:
@ -421,9 +508,9 @@ func (o *oracle) fprintf(w io.Writer, pos interface{}, format string, args ...in
}: }:
start = pos.Pos() start = pos.Pos()
end = start end = start
case bool: case *QueryPos:
start = o.startPos start = pos.start
end = o.endPos end = pos.end
case nil: case nil:
// no-op // no-op
default: default:
@ -448,14 +535,14 @@ func (o *oracle) fprintf(w io.Writer, pos interface{}, format string, args ...in
} }
// errorf is like fprintf, but returns a formatted error string. // errorf is like fprintf, but returns a formatted error string.
func (o *oracle) errorf(pos interface{}, format string, args ...interface{}) error { func (o *Oracle) errorf(pos interface{}, format string, args ...interface{}) error {
var buf bytes.Buffer var buf bytes.Buffer
o.fprintf(&buf, pos, format, args...) o.fprintf(&buf, pos, format, args...)
return errors.New(buf.String()) return errors.New(buf.String())
} }
// printNode returns the pretty-printed syntax of n. // printNode returns the pretty-printed syntax of n.
func (o *oracle) printNode(n ast.Node) string { func (o *Oracle) printNode(n ast.Node) string {
var buf bytes.Buffer var buf bytes.Buffer
printer.Fprint(&buf, o.prog.Fset, n) printer.Fprint(&buf, o.prog.Fset, n)
return buf.String() return buf.String()

View File

@ -44,6 +44,7 @@ import (
"strings" "strings"
"testing" "testing"
"code.google.com/p/go.tools/importer"
"code.google.com/p/go.tools/oracle" "code.google.com/p/go.tools/oracle"
) )
@ -249,3 +250,53 @@ func TestOracle(t *testing.T) {
} }
} }
} }
func TestMultipleQueries(t *testing.T) {
// Importer
var buildContext = build.Default
buildContext.GOPATH = "testdata"
imp := importer.New(&importer.Config{Build: &buildContext})
// Oracle
filename := "testdata/src/main/multi.go"
o, err := oracle.New(imp, []string{filename}, nil)
if err != nil {
t.Fatalf("oracle.New failed: %s", err)
}
// QueryPos
pos := filename + ":#54,#58"
qpos, err := oracle.ParseQueryPos(imp, pos, true)
if err != nil {
t.Fatalf("oracle.ParseQueryPos(%q) failed: %s", pos, err)
}
// SSA is built and we have the QueryPos.
// Release the other ASTs and type info to the GC.
imp = nil
// Run different query moes on same scope and selection.
out := new(bytes.Buffer)
for _, mode := range [...]string{"callers", "describe", "freevars"} {
res, err := o.Query(mode, qpos)
if err != nil {
t.Errorf("(*oracle.Oracle).Query(%q) failed: %s", pos, err)
}
capture := new(bytes.Buffer) // capture standard output
res.WriteTo(capture)
for _, line := range strings.Split(capture.String(), "\n") {
fmt.Fprintf(out, "%s\n", stripLocation(line))
}
}
want := `multi.f is called from these 1 sites:
static function call from multi.main
function call (or conversion) of type ()
Free identifiers:
var x int
`
if got := out.String(); got != want {
t.Errorf("Query output differs; want <<%s>>, got <<%s>>\n", want, got)
}
}

View File

@ -22,10 +22,10 @@ import (
// TODO(adonovan): permit the user to query based on a MakeChan (not send/recv), // TODO(adonovan): permit the user to query based on a MakeChan (not send/recv),
// or the implicit receive in "for v := range ch". // or the implicit receive in "for v := range ch".
// //
func peers(o *oracle) (queryResult, error) { func peers(o *Oracle, qpos *QueryPos) (queryResult, error) {
arrowPos := findArrow(o) arrowPos := findArrow(qpos)
if arrowPos == token.NoPos { if arrowPos == token.NoPos {
return nil, o.errorf(o.queryPath[0], "there is no send/receive here") return nil, o.errorf(qpos.path[0], "there is no send/receive here")
} }
buildSSA(o) buildSSA(o)
@ -111,8 +111,8 @@ func peers(o *oracle) (queryResult, error) {
// findArrow returns the position of the enclosing send/receive op // findArrow returns the position of the enclosing send/receive op
// (<-) for the query position, or token.NoPos if not found. // (<-) for the query position, or token.NoPos if not found.
// //
func findArrow(o *oracle) token.Pos { func findArrow(qpos *QueryPos) token.Pos {
for _, n := range o.queryPath { for _, n := range qpos.path {
switch n := n.(type) { switch n := n.(type) {
case *ast.UnaryExpr: case *ast.UnaryExpr:
if n.Op == token.ARROW { if n.Op == token.ARROW {

View File

@ -16,16 +16,16 @@ import (
// Referrers reports all identifiers that resolve to the same object // Referrers reports all identifiers that resolve to the same object
// as the queried identifier, within any package in the analysis scope. // as the queried identifier, within any package in the analysis scope.
// //
func referrers(o *oracle) (queryResult, error) { func referrers(o *Oracle, qpos *QueryPos) (queryResult, error) {
id, _ := o.queryPath[0].(*ast.Ident) id, _ := qpos.path[0].(*ast.Ident)
if id == nil { if id == nil {
return nil, o.errorf(false, "no identifier here") return nil, o.errorf(qpos, "no identifier here")
} }
obj := o.queryPkgInfo.ObjectOf(id) obj := qpos.info.ObjectOf(id)
if obj == nil { if obj == nil {
// Happens for y in "switch y := x.(type)", but I think that's all. // Happens for y in "switch y := x.(type)", but I think that's all.
return nil, o.errorf(false, "no object for identifier") return nil, o.errorf(qpos, "no object for identifier")
} }
// Iterate over all go/types' resolver facts for the entire program. // Iterate over all go/types' resolver facts for the entire program.

View File

@ -80,7 +80,6 @@ Error: this is a type conversion, not a function call
-------- @callees callees-err-bad-selection -------- -------- @callees callees-err-bad-selection --------
Error: ambiguous selection within function call (or conversion) Error: ambiguous selection within function call (or conversion)
-------- @callees callees-err-deadcode1 -------- -------- @callees callees-err-deadcode1 --------
Error: this call site is unreachable in this analysis Error: this call site is unreachable in this analysis

View File

@ -1,11 +1,11 @@
-------- @freevars fv1 -------- -------- @freevars fv1 --------
Free identifers: Free identifiers:
type C main.C type C main.C
const exp int const exp int
var x int var x int
-------- @freevars fv2 -------- -------- @freevars fv2 --------
Free identifers: Free identifiers:
var s.t.a int var s.t.a int
var s.t.b int var s.t.b int
var s.x int var s.x int
@ -13,6 +13,6 @@ var x int
var y int32 var y int32
-------- @freevars fv3 -------- -------- @freevars fv3 --------
Free identifers: Free identifiers:
var x int var x int

13
oracle/testdata/src/main/multi.go vendored Normal file
View File

@ -0,0 +1,13 @@
package multi
func g(x int) {
}
func f() {
x := 1
g(x) // "g(x)" is the selection for multiple queries
}
func main() {
f()
}