diff --git a/call/util.go b/call/util.go index d4f39236..855e3701 100644 --- a/call/util.go +++ b/call/util.go @@ -20,7 +20,7 @@ package call // Add a utility function to eliminate all context from a call graph. // CalleesOf returns a new set containing all direct callees of the -// caller node in call graph g. +// caller node. // func CalleesOf(caller GraphNode) map[GraphNode]bool { callees := make(map[GraphNode]bool) @@ -31,23 +31,33 @@ func CalleesOf(caller GraphNode) map[GraphNode]bool { } // GraphVisitEdges visits all the edges in graph g in depth-first order. -// The edge function is called for each edge in postorder. +// The edge function is called for each edge in postorder. If it +// returns non-nil, visitation stops and GraphVisitEdges returns that +// value. // -func GraphVisitEdges(g Graph, edge func(Edge)) { +func GraphVisitEdges(g Graph, edge func(Edge) error) error { seen := make(map[GraphNode]bool) - var visit func(n GraphNode) - visit = func(n GraphNode) { + var visit func(n GraphNode) error + visit = func(n GraphNode) error { if !seen[n] { seen[n] = true for _, e := range n.Edges() { - visit(e.Callee) - edge(e) + if err := visit(e.Callee); err != nil { + return err + } + if err := edge(e); err != nil { + return err + } } } + return nil } for _, n := range g.Nodes() { - visit(n) + if err := visit(n); err != nil { + return err + } } + return nil } // PathSearch finds an arbitrary path starting at node start and diff --git a/oracle/callees.go b/oracle/callees.go index 3362bb5a..3d6a3b6b 100644 --- a/oracle/callees.go +++ b/oracle/callees.go @@ -12,7 +12,6 @@ import ( "code.google.com/p/go.tools/go/types" "code.google.com/p/go.tools/oracle/serial" - "code.google.com/p/go.tools/pointer" "code.google.com/p/go.tools/ssa" ) @@ -22,14 +21,19 @@ import ( // TODO(adonovan): if a callee is a wrapper, show the callee's callee. // func callees(o *Oracle, qpos *QueryPos) (queryResult, error) { + pkg := o.prog.Package(qpos.info.Pkg) + if pkg == nil { + return nil, fmt.Errorf("no SSA package") + } + // Determine the enclosing call for the specified position. - var call *ast.CallExpr + var e *ast.CallExpr for _, n := range qpos.path { - if call, _ = n.(*ast.CallExpr); call != nil { + if e, _ = n.(*ast.CallExpr); e != nil { break } } - if call == nil { + if e == nil { return nil, fmt.Errorf("there is no function call here") } // TODO(adonovan): issue an error if the call is "too far @@ -37,12 +41,12 @@ func callees(o *Oracle, qpos *QueryPos) (queryResult, error) { // not what the user intended. // Reject type conversions. - if qpos.info.IsType(call.Fun) { + if qpos.info.IsType(e.Fun) { return nil, fmt.Errorf("this is a type conversion, not a function call") } // Reject calls to built-ins. - if id, ok := unparen(call.Fun).(*ast.Ident); ok { + if id, ok := unparen(e.Fun).(*ast.Ident); ok { if b, ok := qpos.info.ObjectOf(id).(*types.Builtin); ok { return nil, fmt.Errorf("this is a call to the built-in '%s' operator", b.Name()) } @@ -50,64 +54,68 @@ func callees(o *Oracle, qpos *QueryPos) (queryResult, error) { buildSSA(o) - // Compute the subgraph of the callgraph for callsite(s) - // arising from 'call'. There may be more than one if its - // enclosing function was treated context-sensitively. - // (Or zero if it was in dead code.) - // - // The presence of a key indicates this call site is - // interesting even if the value is nil. - querySites := make(map[pointer.CallSite][]pointer.CallGraphNode) - var arbitrarySite pointer.CallSite - o.config.CallSite = func(site pointer.CallSite) { - if site.Pos() == call.Lparen { - // Not a no-op! Ensures key is - // present even if value is nil: - querySites[site] = querySites[site] - arbitrarySite = site - } - } - o.config.Call = func(site pointer.CallSite, callee pointer.CallGraphNode) { - if targets, ok := querySites[site]; ok { - querySites[site] = append(targets, callee) - } - } - ptrAnalysis(o) - - if arbitrarySite == nil { - return nil, fmt.Errorf("this call site is unreachable in this analysis") + // Ascertain calling function and call site. + callerFn := ssa.EnclosingFunction(pkg, qpos.path) + if callerFn == nil { + return nil, fmt.Errorf("no SSA function built for this location (dead code?)") } - // Compute union of callees across all contexts. - funcsMap := make(map[*ssa.Function]bool) - for _, callees := range querySites { - for _, callee := range callees { - funcsMap[callee.Func()] = true + o.config.BuildCallGraph = true + callgraph := ptrAnalysis(o).CallGraph + + // Find the call site and all edges from it. + var site ssa.CallInstruction + calleesMap := make(map[*ssa.Function]bool) + for _, n := range callgraph.Nodes() { + if n.Func() == callerFn { + if site == nil { + // First node for callerFn: identify the site. + for _, s := range n.Sites() { + if s.Pos() == e.Lparen { + site = s + break + } + } + if site == nil { + return nil, fmt.Errorf("this call site is unreachable in this analysis") + } + } + + for _, edge := range n.Edges() { + if edge.Site == site { + calleesMap[edge.Callee.Func()] = true + } + } } } - funcs := make([]*ssa.Function, 0, len(funcsMap)) - for f := range funcsMap { + if site == nil { + return nil, fmt.Errorf("this function is unreachable in this analysis") + } + + // Discard context, de-duplicate and sort. + funcs := make([]*ssa.Function, 0, len(calleesMap)) + for f := range calleesMap { funcs = append(funcs, f) } sort.Sort(byFuncPos(funcs)) return &calleesResult{ - site: arbitrarySite, + site: site, funcs: funcs, }, nil } type calleesResult struct { - site pointer.CallSite + site ssa.CallInstruction funcs []*ssa.Function } func (r *calleesResult) display(printf printfFunc) { if len(r.funcs) == 0 { // dynamic call on a provably nil func/interface - printf(r.site, "%s on nil value", r.site.Description()) + printf(r.site, "%s on nil value", r.site.Common().Description()) } else { - printf(r.site, "this %s dispatches to:", r.site.Description()) + printf(r.site, "this %s dispatches to:", r.site.Common().Description()) for _, callee := range r.funcs { printf(callee, "\t%s", callee) } @@ -117,7 +125,7 @@ func (r *calleesResult) display(printf printfFunc) { func (r *calleesResult) toSerial(res *serial.Result, fset *token.FileSet) { j := &serial.Callees{ Pos: fset.Position(r.site.Pos()).String(), - Desc: r.site.Description(), + Desc: r.site.Common().Description(), } for _, callee := range r.funcs { j.Callees = append(j.Callees, &serial.CalleesItem{ diff --git a/oracle/callers.go b/oracle/callers.go index bba19d21..e773a2d6 100644 --- a/oracle/callers.go +++ b/oracle/callers.go @@ -8,8 +8,8 @@ import ( "fmt" "go/token" + "code.google.com/p/go.tools/call" "code.google.com/p/go.tools/oracle/serial" - "code.google.com/p/go.tools/pointer" "code.google.com/p/go.tools/ssa" ) @@ -36,54 +36,57 @@ func callers(o *Oracle, qpos *QueryPos) (queryResult, error) { // Run the pointer analysis, recording each // call found to originate from target. - var calls []pointer.CallSite - o.config.Call = func(site pointer.CallSite, callee pointer.CallGraphNode) { - if callee.Func() == target { - calls = append(calls, site) + o.config.BuildCallGraph = true + callgraph := ptrAnalysis(o).CallGraph + var edges []call.Edge + call.GraphVisitEdges(callgraph, func(edge call.Edge) error { + if edge.Callee.Func() == target { + edges = append(edges, edge) } - } - // TODO(adonovan): sort calls, to ensure test determinism. - - root := ptrAnalysis(o) + return nil + }) + // TODO(adonovan): sort + dedup calls to ensure test determinism. return &callersResult{ - target: target, - root: root, - calls: calls, + target: target, + callgraph: callgraph, + edges: edges, }, nil } type callersResult struct { - target *ssa.Function - root pointer.CallGraphNode - calls []pointer.CallSite + target *ssa.Function + callgraph call.Graph + edges []call.Edge } func (r *callersResult) display(printf printfFunc) { - if r.calls == nil { + root := r.callgraph.Root() + if r.edges == nil { printf(r.target, "%s is not reachable in this program.", r.target) } else { - printf(r.target, "%s is called from these %d sites:", r.target, len(r.calls)) - for _, site := range r.calls { - if site.Caller() == r.root { + printf(r.target, "%s is called from these %d sites:", r.target, len(r.edges)) + for _, edge := range r.edges { + if edge.Caller == root { printf(r.target, "the root of the call graph") } else { - printf(site, "\t%s from %s", site.Description(), site.Caller().Func()) + printf(edge.Site, "\t%s from %s", edge.Site.Common().Description(), edge.Caller.Func()) } } } } func (r *callersResult) toSerial(res *serial.Result, fset *token.FileSet) { + root := r.callgraph.Root() var callers []serial.Caller - for _, site := range r.calls { + for _, edge := range r.edges { var c serial.Caller - c.Caller = site.Caller().Func().String() - if site.Caller() == r.root { + c.Caller = edge.Caller.Func().String() + if edge.Caller == root { c.Desc = "synthetic call" } else { - c.Pos = fset.Position(site.Pos()).String() - c.Desc = site.Description() + c.Pos = fset.Position(edge.Site.Pos()).String() + c.Desc = edge.Site.Common().Description() } callers = append(callers, c) } diff --git a/oracle/callgraph.go b/oracle/callgraph.go index 3598fd84..590fa640 100644 --- a/oracle/callgraph.go +++ b/oracle/callgraph.go @@ -8,8 +8,8 @@ import ( "go/token" "strings" + "code.google.com/p/go.tools/call" "code.google.com/p/go.tools/oracle/serial" - "code.google.com/p/go.tools/pointer" ) // callgraph displays the entire callgraph of the current program. @@ -23,42 +23,24 @@ import ( // TODO(adonovan): add an option to project away context sensitivity. // The callgraph API should provide this feature. // +// TODO(adonovan): add an option to partition edges by call site. +// // TODO(adonovan): elide nodes for synthetic functions? // func callgraph(o *Oracle, _ *QueryPos) (queryResult, error) { buildSSA(o) // Run the pointer analysis and build the complete callgraph. - callgraph := make(pointer.CallGraph) - o.config.Call = callgraph.AddEdge - root := ptrAnalysis(o) - - // Assign (preorder) numbers to all the callgraph nodes. - // TODO(adonovan): the callgraph API should do this for us. - // (Actually, it does have unique numbers under the hood.) - numbering := make(map[pointer.CallGraphNode]int) - var number func(cgn pointer.CallGraphNode) - number = func(cgn pointer.CallGraphNode) { - if _, ok := numbering[cgn]; !ok { - numbering[cgn] = len(numbering) - for callee := range callgraph[cgn] { - number(callee) - } - } - } - number(root) + o.config.BuildCallGraph = true + ptares := ptrAnalysis(o) return &callgraphResult{ - root: root, - callgraph: callgraph, - numbering: numbering, + callgraph: ptares.CallGraph, }, nil } type callgraphResult struct { - root pointer.CallGraphNode - callgraph pointer.CallGraph - numbering map[pointer.CallGraphNode]int + callgraph call.Graph } func (r *callgraphResult) display(printf printfFunc) { @@ -71,34 +53,41 @@ Some nodes may appear multiple times due to context-sensitive treatment of some calls. `) - // TODO(adonovan): compute the numbers as we print; right now - // it depends on map iteration so it's arbitrary,which is ugly. - seen := make(map[pointer.CallGraphNode]bool) - var print func(cgn pointer.CallGraphNode, indent int) - print = func(cgn pointer.CallGraphNode, indent int) { - n := r.numbering[cgn] - if !seen[cgn] { - seen[cgn] = true - printf(cgn.Func(), "%d\t%s%s", n, strings.Repeat(" ", indent), cgn.Func()) - for callee := range r.callgraph[cgn] { + seen := make(map[call.GraphNode]int) + var print func(cgn call.GraphNode, indent int) + print = func(cgn call.GraphNode, indent int) { + fn := cgn.Func() + if num, ok := seen[cgn]; !ok { + num = len(seen) + seen[cgn] = num + printf(fn, "%d\t%s%s", num, strings.Repeat(" ", indent), fn) + // Don't use Edges(), which distinguishes callees by call site. + for callee := range call.CalleesOf(cgn) { print(callee, indent+1) } } else { - printf(cgn.Func(), "\t%s%s (%d)", strings.Repeat(" ", indent), cgn.Func(), n) + printf(fn, "\t%s%s (%d)", strings.Repeat(" ", indent), fn, num) } } - print(r.root, 0) + print(r.callgraph.Root(), 0) } func (r *callgraphResult) toSerial(res *serial.Result, fset *token.FileSet) { - cg := make([]serial.CallGraph, len(r.numbering)) - for n, i := range r.numbering { + nodes := r.callgraph.Nodes() + + numbering := make(map[call.GraphNode]int) + for i, n := range nodes { + numbering[n] = i + } + + cg := make([]serial.CallGraph, len(nodes)) + for i, n := range nodes { j := &cg[i] fn := n.Func() j.Name = fn.String() j.Pos = fset.Position(fn.Pos()).String() - for callee := range r.callgraph[n] { - j.Children = append(j.Children, r.numbering[callee]) + for callee := range call.CalleesOf(n) { + j.Children = append(j.Children, numbering[callee]) } } res.Callgraph = cg diff --git a/oracle/callstack.go b/oracle/callstack.go index 2c3aeba7..57adfff2 100644 --- a/oracle/callstack.go +++ b/oracle/callstack.go @@ -8,8 +8,8 @@ import ( "fmt" "go/token" + "code.google.com/p/go.tools/call" "code.google.com/p/go.tools/oracle/serial" - "code.google.com/p/go.tools/pointer" "code.google.com/p/go.tools/ssa" ) @@ -41,57 +41,36 @@ func callstack(o *Oracle, qpos *QueryPos) (queryResult, error) { } // Run the pointer analysis and build the complete call graph. - callgraph := make(pointer.CallGraph) - o.config.Call = callgraph.AddEdge - root := ptrAnalysis(o) + o.config.BuildCallGraph = true + callgraph := ptrAnalysis(o).CallGraph - seen := make(map[pointer.CallGraphNode]bool) - var callstack []pointer.CallSite - - // Use depth-first search to find an arbitrary path from a - // root to the target function. - var search func(cgn pointer.CallGraphNode) bool - search = func(cgn pointer.CallGraphNode) bool { - if !seen[cgn] { - seen[cgn] = true - if cgn.Func() == target { - return true - } - for callee, site := range callgraph[cgn] { - if search(callee) { - callstack = append(callstack, site) - return true - } - } - } - return false - } - - for toplevel := range callgraph[root] { - if search(toplevel) { - break - } + // Search for an arbitrary path from a root to the target function. + isEnd := func(n call.GraphNode) bool { return n.Func() == target } + callpath := call.PathSearch(callgraph.Root(), isEnd) + if callpath != nil { + callpath = callpath[1:] // remove synthetic edge from } return &callstackResult{ - qpos: qpos, - target: target, - callstack: callstack, + qpos: qpos, + target: target, + callpath: callpath, }, nil } type callstackResult struct { - qpos *QueryPos - target *ssa.Function - callstack []pointer.CallSite + qpos *QueryPos + target *ssa.Function + callpath []call.Edge } func (r *callstackResult) display(printf printfFunc) { - if r.callstack != nil { + if r.callpath != nil { printf(r.qpos, "Found a call path from root to %s", r.target) printf(r.target, "%s", r.target) - for _, site := range r.callstack { - printf(site, "%s from %s", site.Description(), site.Caller().Func()) + for i := len(r.callpath) - 1; i >= 0; i-- { + edge := r.callpath[i] + printf(edge.Site, "%s from %s", edge.Site.Common().Description(), edge.Caller.Func()) } } else { printf(r.target, "%s is unreachable in this analysis scope", r.target) @@ -100,11 +79,12 @@ func (r *callstackResult) display(printf printfFunc) { func (r *callstackResult) toSerial(res *serial.Result, fset *token.FileSet) { var callers []serial.Caller - for _, site := range r.callstack { + for i := len(r.callpath) - 1; i >= 0; i-- { // (innermost first) + edge := r.callpath[i] callers = append(callers, serial.Caller{ - Pos: fset.Position(site.Pos()).String(), - Caller: site.Caller().Func().String(), - Desc: site.Description(), + Pos: fset.Position(edge.Site.Pos()).String(), + Caller: edge.Caller.Func().String(), + Desc: edge.Site.Common().Description(), }) } res.Callstack = &serial.CallStack{ diff --git a/oracle/describe.go b/oracle/describe.go index 20941915..23da2175 100644 --- a/oracle/describe.go +++ b/oracle/describe.go @@ -419,11 +419,11 @@ func describePointer(o *Oracle, v ssa.Value, indirect bool) (ptrs []pointerResul buildSSA(o) // TODO(adonovan): don't run indirect pointer analysis on non-ptr-ptrlike types. - o.config.QueryValues = map[ssa.Value]pointer.Indirect{v: pointer.Indirect(indirect)} - ptrAnalysis(o) + o.config.Queries = map[ssa.Value]pointer.Indirect{v: pointer.Indirect(indirect)} + ptares := ptrAnalysis(o) // Combine the PT sets from all contexts. - pointers := o.config.QueryResults[v] + pointers := ptares.Queries[v] if pointers == nil { return nil, fmt.Errorf("PTA did not encounter this expression (dead code?)") } diff --git a/oracle/oracle.go b/oracle/oracle.go index f3b6781c..3c0a7510 100644 --- a/oracle/oracle.go +++ b/oracle/oracle.go @@ -358,14 +358,12 @@ func buildSSA(o *Oracle) { o.timers["SSA-build"] = time.Since(start) } -// ptrAnalysis runs the pointer analysis and returns the synthetic -// root of the callgraph. -// -func ptrAnalysis(o *Oracle) pointer.CallGraphNode { +// ptrAnalysis runs the pointer analysis and returns its result. +func ptrAnalysis(o *Oracle) *pointer.Result { start := time.Now() - root := pointer.Analyze(&o.config) + result := pointer.Analyze(&o.config) o.timers["pointer analysis"] = time.Since(start) - return root + return result } // parseOctothorpDecimal returns the numeric value if s matches "#%d", diff --git a/oracle/peers.go b/oracle/peers.go index 41010ae6..65e7ac2d 100644 --- a/oracle/peers.go +++ b/oracle/peers.go @@ -71,11 +71,11 @@ func peers(o *Oracle, qpos *QueryPos) (queryResult, error) { ops = ops[:i] // Run the pointer analysis. - o.config.QueryValues = channels - ptrAnalysis(o) + o.config.Queries = channels + ptares := ptrAnalysis(o) // Combine the PT sets from all contexts. - queryChanPts := pointer.PointsToCombined(o.config.QueryResults[queryOp.ch]) + queryChanPts := pointer.PointsToCombined(ptares.Queries[queryOp.ch]) // Ascertain which make(chan) labels the query's channel can alias. var makes []token.Pos @@ -87,7 +87,7 @@ func peers(o *Oracle, qpos *QueryPos) (queryResult, error) { // Ascertain which send/receive operations can alias the same make(chan) labels. var sends, receives []token.Pos for _, op := range ops { - for _, ptr := range o.config.QueryResults[op.ch] { + for _, ptr := range ptares.Queries[op.ch] { if ptr != nil && ptr.PointsTo().Intersects(queryChanPts) { if op.dir == ast.SEND { sends = append(sends, op.pos) diff --git a/oracle/testdata/src/main/calls.golden b/oracle/testdata/src/main/calls.golden index 259a7980..8ba18e8d 100644 --- a/oracle/testdata/src/main/calls.golden +++ b/oracle/testdata/src/main/calls.golden @@ -88,7 +88,7 @@ dynamic method call on nil value -------- @callees callees-err-deadcode2 -------- -Error: this call site is unreachable in this analysis +Error: this function is unreachable in this analysis -------- @callstack callstack-err-deadcode -------- main.deadcode is unreachable in this analysis scope diff --git a/pointer/TODO b/pointer/TODO index d5f12bc7..10afd49a 100644 --- a/pointer/TODO +++ b/pointer/TODO @@ -32,12 +32,6 @@ SOLVER: dannyb recommends sparse bitmap. API: -- Rely less on callbacks and more on a 'result' type - returned by Analyze(). -- Abstract the callgraph into a pure interface so that - we can provide other implementations in future (e.g. RTA-based). - Also provide the option to eliminate context-sensitivity - in a callgraph to yield a smaller (less precise) callgraph. - Some optimisations (e.g. LE, PE) may change the API. Think about them sooner rather than later. - Eliminate Print probe now that we can query specific ssa.Values. diff --git a/pointer/analysis.go b/pointer/analysis.go index b825bb11..d645365f 100644 --- a/pointer/analysis.go +++ b/pointer/analysis.go @@ -173,13 +173,14 @@ type analysis struct { nodes []*node // indexed by nodeid flattenMemo map[types.Type][]*fieldInfo // memoization of flatten() constraints []constraint // set of constraints - callsites []*callsite // all callsites + cgnodes []*cgnode // all cgnodes genq []*cgnode // queue of functions to generate constraints for intrinsics map[*ssa.Function]intrinsic // non-nil values are summaries for intrinsic fns funcObj map[*ssa.Function]nodeid // default function object for each func probes map[*ssa.CallCommon]nodeid // maps call to print() to argument variable valNode map[ssa.Value]nodeid // node for each ssa.Value work worklist // solver's worklist + queries map[ssa.Value][]Pointer // same as Results.Queries // Reflection: hasher typemap.Hasher // cache of type hashes @@ -229,7 +230,7 @@ func (a *analysis) warnf(pos token.Pos, format string, args ...interface{}) { // Analyze runs the pointer analysis with the scope and options // specified by config, and returns the (synthetic) root of the callgraph. // -func Analyze(config *Config) CallGraphNode { +func Analyze(config *Config) *Result { a := &analysis{ config: config, log: config.Log, @@ -241,6 +242,7 @@ func Analyze(config *Config) CallGraphNode { funcObj: make(map[*ssa.Function]nodeid), probes: make(map[*ssa.CallCommon]nodeid), work: makeMapWorklist(), + queries: make(map[ssa.Value][]Pointer), } if reflect := a.prog.ImportedPackage("reflect"); reflect != nil { @@ -265,7 +267,7 @@ func Analyze(config *Config) CallGraphNode { fmt.Fprintln(a.log, "======== NEW ANALYSIS ========") } - root := a.generate() + a.generate() //a.optimize() @@ -280,32 +282,33 @@ func Analyze(config *Config) CallGraphNode { } } - // Notify the client of the callsites if they're interested. - if CallSite := a.config.CallSite; CallSite != nil { - for _, site := range a.callsites { - CallSite(site) - } - } + // Visit discovered call graph. + for _, caller := range a.cgnodes { + for _, site := range caller.sites { + for nid := range a.nodes[site.targets].pts { + callee := a.nodes[nid].obj.cgn - Call := a.config.Call - for _, site := range a.callsites { - for nid := range a.nodes[site.targets].pts { - cgn := a.nodes[nid].obj.cgn + if a.config.BuildCallGraph { + site.callees = append(site.callees, callee) + } - // Notify the client of the call graph, if - // they're interested. - if Call != nil { - Call(site, cgn) - } - - // Warn about calls to non-intrinsic external functions. - - if fn := cgn.fn; fn.Blocks == nil && a.findIntrinsic(fn) == nil { - a.warnf(site.Pos(), "unsound call to unknown intrinsic: %s", fn) - a.warnf(fn.Pos(), " (declared here)") + // TODO(adonovan): de-dup these messages. + // Warn about calls to non-intrinsic external functions. + if fn := callee.fn; fn.Blocks == nil && a.findIntrinsic(fn) == nil { + a.warnf(site.pos(), "unsound call to unknown intrinsic: %s", fn) + a.warnf(fn.Pos(), " (declared here)") + } } } } - return root + var callgraph *cgraph + if a.config.BuildCallGraph { + callgraph = &cgraph{a.cgnodes} + } + + return &Result{ + CallGraph: callgraph, + Queries: a.queries, + } } diff --git a/pointer/api.go b/pointer/api.go index f1f4e472..3867b2a4 100644 --- a/pointer/api.go +++ b/pointer/api.go @@ -9,6 +9,7 @@ import ( "go/token" "io" + "code.google.com/p/go.tools/call" "code.google.com/p/go.tools/go/types/typemap" "code.google.com/p/go.tools/ssa" ) @@ -27,31 +28,12 @@ type Config struct { // has not yet been reduced by presolver optimisation. Reflection bool + // BuildCallGraph determines whether to construct a callgraph. + // If enabled, the graph will be available in Result.CallGraph. + BuildCallGraph bool + // -------- Optional callbacks invoked by the analysis -------- - // Call is invoked for each discovered call-graph edge. The - // call-graph is a multigraph over CallGraphNodes with edges - // labelled by the CallSite that gives rise to the edge. - // (The caller node is available as site.Caller()) - // - // Clients that wish to construct a call graph may provide - // CallGraph.AddEdge here. - // - // The callgraph may be context-sensitive, i.e. it may - // distinguish separate calls to the same function depending - // on the context. - // - Call func(site CallSite, callee CallGraphNode) - - // CallSite is invoked for each call-site encountered in the - // program. - // - // The callgraph may be context-sensitive, i.e. it may - // distinguish separate calls to the same function depending - // on the context. - // - CallSite func(site CallSite) - // Warn is invoked for each warning encountered by the analysis, // e.g. unknown external function, unsound use of unsafe.Pointer. // pos may be zero if the position is not known. @@ -71,8 +53,8 @@ type Config struct { // Print func(site *ssa.CallCommon, p Pointer) - // The client populates QueryValues[v] for each ssa.Value v - // of interest. + // The client populates Queries[v] for each ssa.Value v of + // interest. // // The boolean (Indirect) indicates whether to compute the // points-to set for v (false) or *v (true): the latter is @@ -80,20 +62,16 @@ type Config struct { // lvalues, e.g. an *ssa.Global. // // The pointer analysis will populate the corresponding - // QueryResults value when it creates the pointer variable - // for v or *v. Upon completion the client can inspect the + // Results.Queries value when it creates the pointer variable + // for v or *v. Upon completion the client can inspect that // map for the results. // // If a Value belongs to a function that the analysis treats - // context-sensitively, the corresponding QueryResults slice + // context-sensitively, the corresponding Results.Queries slice // may have multiple Pointers, one per distinct context. Use // PointsToCombined to merge them. // - // TODO(adonovan): refactor the API: separate all results of - // Analyze() into a dedicated Result struct. - // - QueryValues map[ssa.Value]Indirect - QueryResults map[ssa.Value][]Pointer + Queries map[ssa.Value]Indirect // -------- Other configuration options -------- @@ -111,10 +89,19 @@ func (c *Config) prog() *ssa.Program { panic("empty scope") } +// A Result contains the results of a pointer analysis. +// +// See Config for how to request the various Result components. +// +type Result struct { + CallGraph call.Graph // discovered call graph + Queries map[ssa.Value][]Pointer // points-to sets for queried ssa.Values +} + // A Pointer is an equivalence class of pointerlike values. // // TODO(adonovan): add a method -// Context() CallGraphNode +// Context() call.GraphNode // for pointers corresponding to local variables, // type Pointer interface { diff --git a/pointer/callgraph.go b/pointer/callgraph.go index af328a9e..e3d8bcea 100644 --- a/pointer/callgraph.go +++ b/pointer/callgraph.go @@ -4,116 +4,93 @@ package pointer +// This file defines our implementation of the call.Graph API. + import ( "fmt" "go/token" + "code.google.com/p/go.tools/call" "code.google.com/p/go.tools/ssa" ) -// TODO(adonovan): move the CallGraph, CallGraphNode, CallSite types -// into a separate package 'callgraph', and make them pure interfaces -// capable of supporting several implementations (context-sensitive -// and insensitive PTA, RTA, etc). - -// ---------- CallGraphNode ---------- - -// A CallGraphNode is a context-sensitive representation of a node in -// the callgraph. In other words, there may be multiple nodes -// representing a single *Function, depending on the contexts in which -// it is called. The identity of the node is therefore important. -// -type CallGraphNode interface { - Func() *ssa.Function // the function this node represents - String() string // diagnostic description of this callgraph node +// cgraph implements call.Graph. +type cgraph struct { + nodes []*cgnode } +func (g *cgraph) Nodes() []call.GraphNode { + nodes := make([]call.GraphNode, len(g.nodes)) + for i, node := range g.nodes { + nodes[i] = node + } + return nodes +} + +func (g *cgraph) Root() call.GraphNode { + return g.nodes[0] +} + +// cgnode implements call.GraphNode. type cgnode struct { - fn *ssa.Function - obj nodeid // start of this contour's object block + fn *ssa.Function + obj nodeid // start of this contour's object block + sites []*callsite // ordered list of callsites within this function } func (n *cgnode) Func() *ssa.Function { return n.fn } +func (n *cgnode) Sites() []ssa.CallInstruction { + sites := make([]ssa.CallInstruction, len(n.sites)) + for i, site := range n.sites { + sites[i] = site.instr + } + return sites +} + +func (n *cgnode) Edges() []call.Edge { + var numEdges int + for _, site := range n.sites { + numEdges += len(site.callees) + } + edges := make([]call.Edge, 0, numEdges) + + for _, site := range n.sites { + for _, callee := range site.callees { + edges = append(edges, call.Edge{Caller: n, Site: site.instr, Callee: callee}) + } + } + return edges +} + func (n *cgnode) String() string { return fmt.Sprintf("cg%d:%s", n.obj, n.fn) } -// ---------- CallSite ---------- - -// A CallSite is a context-sensitive representation of a function call -// site in the program. -// -type CallSite interface { - Caller() CallGraphNode // the enclosing context of this call - Pos() token.Pos // source position; token.NoPos for synthetic calls - Description() string // UI description of call kind; see (*ssa.CallCommon).Description - String() string // diagnostic description of this callsite -} - -// A callsite represents a single function or method callsite within a -// function. callsites never represent calls to built-ins; they are -// handled as intrinsics. +// A callsite represents a single call site within a cgnode; +// it is implicitly context-sensitive. +// callsites never represent calls to built-ins; +// they are handled as intrinsics. // type callsite struct { - caller *cgnode // the origin of the call targets nodeid // pts(targets) contains identities of all called functions. - instr ssa.CallInstruction // optional call instruction; provides IsInvoke, position, etc. - pos token.Pos // position, if instr == nil, i.e. synthetic callsites. + instr ssa.CallInstruction // the call instruction; nil for synthetic/intrinsic + callees []*cgnode // unordered set of callees of this site } -// Caller returns the node in the callgraph from which this call originated. -func (c *callsite) Caller() CallGraphNode { - return c.caller -} - -// Description returns a description of this kind of call, in the -// manner of ssa.CallCommon.Description(). -// -func (c *callsite) Description() string { +func (c *callsite) String() string { if c.instr != nil { return c.instr.Common().Description() } return "synthetic function call" } -// Pos returns the source position of this callsite, or token.NoPos if implicit. -func (c *callsite) Pos() token.Pos { +// pos returns the source position of this callsite, or token.NoPos if implicit. +func (c *callsite) pos() token.Pos { if c.instr != nil { return c.instr.Pos() } - return c.pos -} - -func (c *callsite) String() string { - // TODO(adonovan): provide more info, e.g. target of static - // call, arguments, location. - return c.Description() -} - -// ---------- CallGraph ---------- - -// CallGraph is a forward directed graph of functions labelled by an -// arbitrary site within the caller. -// -// CallGraph.AddEdge may be used as the Context.Call callback for -// clients that wish to construct a call graph. -// -// TODO(adonovan): this is just a starting point. Add options to -// control whether we record no callsite, an arbitrary callsite, or -// all callsites for a given graph edge. Also, this could live in -// another package since it's just a client utility. -// -type CallGraph map[CallGraphNode]map[CallGraphNode]CallSite - -func (cg CallGraph) AddEdge(site CallSite, callee CallGraphNode) { - caller := site.Caller() - callees := cg[caller] - if callees == nil { - callees = make(map[CallGraphNode]CallSite) - cg[caller] = callees - } - callees[callee] = site // save an arbitrary site + return token.NoPos } diff --git a/pointer/example_test.go b/pointer/example_test.go index fae61954..bd86818a 100644 --- a/pointer/example_test.go +++ b/pointer/example_test.go @@ -10,6 +10,7 @@ import ( "go/parser" "sort" + "code.google.com/p/go.tools/call" "code.google.com/p/go.tools/importer" "code.google.com/p/go.tools/pointer" "code.google.com/p/go.tools/ssa" @@ -66,34 +67,23 @@ func main() { prog.BuildAll() // Run the pointer analysis and build the complete callgraph. - callgraph := make(pointer.CallGraph) config := &pointer.Config{ - Mains: []*ssa.Package{mainPkg}, - Call: callgraph.AddEdge, + Mains: []*ssa.Package{mainPkg}, + BuildCallGraph: true, } - root := pointer.Analyze(config) + result := pointer.Analyze(config) - // Visit callgraph in depth-first order. - // - // There may be multiple nodes for the - // same function due to context sensitivity. - var edges []string // call edges originating from the main package. - seen := make(map[pointer.CallGraphNode]bool) - var visit func(cgn pointer.CallGraphNode) - visit = func(cgn pointer.CallGraphNode) { - if seen[cgn] { - return // already seen + // Find edges originating from the main package. + // By converting to strings, we de-duplicate nodes + // representing the same function due to context sensitivity. + var edges []string + call.GraphVisitEdges(result.CallGraph, func(edge call.Edge) error { + caller := edge.Caller.Func() + if caller.Pkg == mainPkg { + edges = append(edges, fmt.Sprint(caller, " --> ", edge.Callee.Func())) } - seen[cgn] = true - caller := cgn.Func() - for callee := range callgraph[cgn] { - if caller.Pkg == mainPkg { - edges = append(edges, fmt.Sprint(caller, " --> ", callee.Func())) - } - visit(callee) - } - } - visit(root) + return nil + }) // Print the edges in sorted order. sort.Strings(edges) diff --git a/pointer/gen.go b/pointer/gen.go index 62fb1039..cebeff5a 100644 --- a/pointer/gen.go +++ b/pointer/gen.go @@ -72,18 +72,13 @@ func (a *analysis) setValueNode(v ssa.Value, id nodeid) { } // Record the (v, id) relation if the client has queried v. - if indirect, ok := a.config.QueryValues[v]; ok { + if indirect, ok := a.config.Queries[v]; ok { if indirect { tmp := a.addNodes(v.Type(), "query.indirect") a.load(tmp, id, a.sizeof(v.Type())) id = tmp } - ptrs := a.config.QueryResults - if ptrs == nil { - ptrs = make(map[ssa.Value][]Pointer) - a.config.QueryResults = ptrs - } - ptrs[v] = append(ptrs[v], ptr{a, id}) + a.queries[v] = append(a.queries[v], ptr{a, id}) } } @@ -125,7 +120,7 @@ func (a *analysis) makeFunctionObject(fn *ssa.Function) nodeid { // obj is the function object (identity, params, results). obj := a.nextNode() - cgn := &cgnode{fn: fn, obj: obj} + cgn := a.makeCGNode(fn, obj) sig := fn.Signature a.addOneNode(sig, "func.cgnode", nil) // (scalar with Signature type) if recv := sig.Recv(); recv != nil { @@ -849,15 +844,12 @@ func (a *analysis) genCall(caller *cgnode, instr ssa.CallInstruction) { } site := &callsite{ - caller: caller, targets: targets, instr: instr, - pos: instr.Pos(), } - a.callsites = append(a.callsites, site) + caller.sites = append(caller.sites, site) if a.log != nil { - fmt.Fprintf(a.log, "\t%s to targets %s from %s\n", - site.Description(), site.targets, site.caller) + fmt.Fprintf(a.log, "\t%s to targets %s from %s\n", site, site.targets, caller) } } @@ -1061,6 +1053,12 @@ func (a *analysis) genInstr(cgn *cgnode, instr ssa.Instruction) { } } +func (a *analysis) makeCGNode(fn *ssa.Function, obj nodeid) *cgnode { + cgn := &cgnode{fn: fn, obj: obj} + a.cgnodes = append(a.cgnodes, cgn) + return cgn +} + // genRootCalls generates the synthetic root of the callgraph and the // initial calls from it to the analysis scope, such as main, a test // or a library. @@ -1070,7 +1068,7 @@ func (a *analysis) genRootCalls() *cgnode { r.Prog = a.prog // hack. r.Enclosing = r // hack, so Function.String() doesn't crash r.String() // (asserts that it doesn't crash) - root := &cgnode{fn: r} + root := a.makeCGNode(r, 0) // For each main package, call main.init(), main.main(). for _, mainPkg := range a.config.Mains { @@ -1080,11 +1078,8 @@ func (a *analysis) genRootCalls() *cgnode { } targets := a.addOneNode(main.Signature, "root.targets", nil) - site := &callsite{ - caller: root, - targets: targets, - } - a.callsites = append(a.callsites, site) + site := &callsite{targets: targets} + root.sites = append(root.sites, site) for _, fn := range [2]*ssa.Function{mainPkg.Func("init"), main} { if a.log != nil { fmt.Fprintf(a.log, "\troot call to %s:\n", fn) diff --git a/pointer/labels.go b/pointer/labels.go index 6f33f1fa..215d24c9 100644 --- a/pointer/labels.go +++ b/pointer/labels.go @@ -9,6 +9,7 @@ import ( "go/token" "strings" + "code.google.com/p/go.tools/call" "code.google.com/p/go.tools/go/types" "code.google.com/p/go.tools/ssa" ) @@ -46,7 +47,7 @@ func (l Label) Value() ssa.Value { // Context returns the analytic context in which this label's object was allocated, // or nil for global objects: global, const, and shared contours for functions. // -func (l Label) Context() CallGraphNode { +func (l Label) Context() call.GraphNode { return l.obj.cgn } diff --git a/pointer/pointer_test.go b/pointer/pointer_test.go index 8799e719..7f4a500c 100644 --- a/pointer/pointer_test.go +++ b/pointer/pointer_test.go @@ -10,6 +10,7 @@ package pointer_test import ( "bytes" + "errors" "fmt" "go/build" "go/parser" @@ -21,6 +22,7 @@ import ( "strings" "testing" + "code.google.com/p/go.tools/call" "code.google.com/p/go.tools/go/types" "code.google.com/p/go.tools/go/types/typemap" "code.google.com/p/go.tools/importer" @@ -286,24 +288,22 @@ func doOneInput(input, filename string) bool { var warnings []string var log bytes.Buffer - callgraph := make(pointer.CallGraph) - // Run the analysis. config := &pointer.Config{ - Reflection: true, - Mains: []*ssa.Package{ptrmain}, - Log: &log, + Reflection: true, + BuildCallGraph: true, + Mains: []*ssa.Package{ptrmain}, + Log: &log, Print: func(site *ssa.CallCommon, p pointer.Pointer) { probes = append(probes, probe{site, p}) }, - Call: callgraph.AddEdge, Warn: func(pos token.Pos, format string, args ...interface{}) { msg := fmt.Sprintf(format, args...) fmt.Printf("%s: warning: %s\n", prog.Fset.Position(pos), msg) warnings = append(warnings, msg) }, } - pointer.Analyze(config) + result := pointer.Analyze(config) // Print the log is there was an error or a panic. complete := false @@ -341,7 +341,7 @@ func doOneInput(input, filename string) bool { } case "calls": - if !checkCallsExpectation(prog, e, callgraph) { + if !checkCallsExpectation(prog, e, result.CallGraph) { ok = false } @@ -463,29 +463,33 @@ func checkTypesExpectation(e *expectation, pr *probe) bool { e.errorf("interface may additionally contain these types: %s", surplus.KeysString()) } return ok - return false } -func checkCallsExpectation(prog *ssa.Program, e *expectation, callgraph pointer.CallGraph) bool { - // TODO(adonovan): this is inefficient and not robust against - // typos. Better to convert strings to *Functions during - // expectation parsing (somehow). - for caller, callees := range callgraph { - if caller.Func().String() == e.args[0] { - found := make(map[string]struct{}) - for callee := range callees { - s := callee.Func().String() - found[s] = struct{}{} - if s == e.args[1] { - return true // expectation satisfied - } +var errOK = errors.New("OK") + +func checkCallsExpectation(prog *ssa.Program, e *expectation, callgraph call.Graph) bool { + found := make(map[string]struct{}) + err := call.GraphVisitEdges(callgraph, func(edge call.Edge) error { + // Name-based matching is inefficient but it allows us to + // match functions whose names that would not appear in an + // index ("") or which are not unique ("func@1.2"). + if edge.Caller.Func().String() == e.args[0] { + calleeStr := edge.Callee.Func().String() + if calleeStr == e.args[1] { + return errOK // expectation satisified; stop the search } - e.errorf("found no call from %s to %s, but only to %s", - e.args[0], e.args[1], join(found)) - return false + found[calleeStr] = struct{}{} } + return nil + }) + if err == errOK { + return true } - e.errorf("didn't find any calls from %s", e.args[0]) + if len(found) == 0 { + e.errorf("didn't find any calls from %s", e.args[0]) + } + e.errorf("found no call from %s to %s, but only to %s", + e.args[0], e.args[1], join(found)) return false }