From 500e9560000dc4521479f08d7f612f662a98f435 Mon Sep 17 00:00:00 2001 From: Daniel Morsing Date: Mon, 18 May 2015 22:53:22 +0100 Subject: [PATCH] oracle: attempt to deduce callees statically before building SSA When querying for callees against a static call, the entire SSA form for the program was built. Since we can tell if a callee is statically dispatched after typechecking, try to do that before building the SSA form. This cuts 3.5 seconds off queries against static calls. Change-Id: I22291381d3bec490e3b1d6f9c6b5a0092fd9f635 Reviewed-on: https://go-review.googlesource.com/10230 Reviewed-by: Alan Donovan --- oracle/callees.go | 117 +++++++++++++++++++------- oracle/testdata/src/calls/main.go | 18 ++++ oracle/testdata/src/calls/main.golden | 12 +++ 3 files changed, 118 insertions(+), 29 deletions(-) diff --git a/oracle/callees.go b/oracle/callees.go index f05c8b28..56e45e13 100644 --- a/oracle/callees.go +++ b/oracle/callees.go @@ -39,18 +39,6 @@ func callees(q *Query) error { return err } - prog := ssautil.CreateProgram(lprog, 0) - - ptaConfig, err := setupPTA(prog, lprog, q.PTALog, q.Reflection) - if err != nil { - return err - } - - pkg := prog.Package(qpos.info.Pkg) - if pkg == nil { - return fmt.Errorf("no SSA package") - } - // Determine the enclosing call for the specified position. var e *ast.CallExpr for _, n := range qpos.path { @@ -70,11 +58,60 @@ func callees(q *Query) error { return fmt.Errorf("this is a type conversion, not a function call") } - // Reject calls to built-ins. - if id, ok := unparen(e.Fun).(*ast.Ident); ok { - if b, ok := qpos.info.Uses[id].(*types.Builtin); ok { - return fmt.Errorf("this is a call to the built-in '%s' operator", b.Name()) + // Deal with obviously static calls before constructing SSA form. + // Some static calls may yet require SSA construction, + // e.g. f := func(){}; f(). + switch funexpr := unparen(e.Fun).(type) { + case *ast.Ident: + switch obj := qpos.info.Uses[funexpr].(type) { + case *types.Builtin: + // Reject calls to built-ins. + return fmt.Errorf("this is a call to the built-in '%s' operator", obj.Name()) + case *types.Func: + // This is a static function call + q.result = &calleesTypesResult{ + site: e, + callee: obj, + } + return nil } + case *ast.SelectorExpr: + sel := qpos.info.Selections[funexpr] + if sel == nil { + // qualified identifier. + // May refer to top level function variable + // or to top level function. + callee := qpos.info.Uses[funexpr.Sel] + if obj, ok := callee.(*types.Func); ok { + q.result = &calleesTypesResult{ + site: e, + callee: obj, + } + return nil + } + } else if sel.Kind() == types.MethodVal { + recvtype := sel.Recv() + if !types.IsInterface(recvtype) { + // static method call + q.result = &calleesTypesResult{ + site: e, + callee: sel.Obj().(*types.Func), + } + return nil + } + } + } + + prog := ssautil.CreateProgram(lprog, ssa.GlobalDebug) + + ptaConfig, err := setupPTA(prog, lprog, q.PTALog, q.Reflection) + if err != nil { + return err + } + + pkg := prog.Package(qpos.info.Pkg) + if pkg == nil { + return fmt.Errorf("no SSA package") } // Defer SSA construction till after errors are reported. @@ -87,7 +124,7 @@ func callees(q *Query) error { } // Find the call site. - site, err := findCallSite(callerFn, e.Lparen) + site, err := findCallSite(callerFn, e) if err != nil { return err } @@ -97,22 +134,20 @@ func callees(q *Query) error { return err } - q.result = &calleesResult{ + q.result = &calleesSSAResult{ site: site, funcs: funcs, } return nil } -func findCallSite(fn *ssa.Function, lparen token.Pos) (ssa.CallInstruction, error) { - for _, b := range fn.Blocks { - for _, instr := range b.Instrs { - if site, ok := instr.(ssa.CallInstruction); ok && instr.Pos() == lparen { - return site, nil - } - } +func findCallSite(fn *ssa.Function, call *ast.CallExpr) (ssa.CallInstruction, error) { + instr, _ := fn.ValueForExpr(call) + callInstr, _ := instr.(ssa.CallInstruction) + if instr == nil { + return nil, fmt.Errorf("this call site is unreachable in this analysis") } - return nil, fmt.Errorf("this call site is unreachable in this analysis") + return callInstr, nil } func findCallees(conf *pointer.Config, site ssa.CallInstruction) ([]*ssa.Function, error) { @@ -154,12 +189,17 @@ func findCallees(conf *pointer.Config, site ssa.CallInstruction) ([]*ssa.Functio return funcs, nil } -type calleesResult struct { +type calleesSSAResult struct { site ssa.CallInstruction funcs []*ssa.Function } -func (r *calleesResult) display(printf printfFunc) { +type calleesTypesResult struct { + site *ast.CallExpr + callee *types.Func +} + +func (r *calleesSSAResult) display(printf printfFunc) { if len(r.funcs) == 0 { // dynamic call on a provably nil func/interface printf(r.site, "%s on nil value", r.site.Common().Description()) @@ -171,7 +211,7 @@ func (r *calleesResult) display(printf printfFunc) { } } -func (r *calleesResult) toSerial(res *serial.Result, fset *token.FileSet) { +func (r *calleesSSAResult) toSerial(res *serial.Result, fset *token.FileSet) { j := &serial.Callees{ Pos: fset.Position(r.site.Pos()).String(), Desc: r.site.Common().Description(), @@ -185,6 +225,25 @@ func (r *calleesResult) toSerial(res *serial.Result, fset *token.FileSet) { res.Callees = j } +func (r *calleesTypesResult) display(printf printfFunc) { + printf(r.site, "this static function call dispatches to:") + printf(r.callee, "\t%s", r.callee.FullName()) +} + +func (r *calleesTypesResult) toSerial(res *serial.Result, fset *token.FileSet) { + j := &serial.Callees{ + Pos: fset.Position(r.site.Pos()).String(), + Desc: "static function call", + } + j.Callees = []*serial.CalleesItem{ + &serial.CalleesItem{ + Name: r.callee.FullName(), + Pos: fset.Position(r.callee.Pos()).String(), + }, + } + res.Callees = j +} + // NB: byFuncPos is not deterministic across packages since it depends on load order. // Use lessPos if the tests need it. type byFuncPos []*ssa.Function diff --git a/oracle/testdata/src/calls/main.go b/oracle/testdata/src/calls/main.go index 7c54e0e9..6fdff95d 100644 --- a/oracle/testdata/src/calls/main.go +++ b/oracle/testdata/src/calls/main.go @@ -1,5 +1,9 @@ package main +import ( + "fmt" +) + // Tests of call-graph queries. // See go.tools/oracle/oracle_test.go for explanation. // See calls.golden for expected query results. @@ -13,6 +17,9 @@ func B(x *int) { // @pointsto pointsto-B-x "x" // @callers callers-B "^" } +func foo() { +} + // apply is not (yet) treated context-sensitively. func apply(f func(x *int), x *int) { f(x) // @callees callees-apply "f" @@ -70,6 +77,12 @@ func main() { i = new(myint) i.f() // @callees callees-not-a-wrapper "f" + + // statically dispatched calls. Handled specially by callees, so test that they work. + foo() // @callees callees-static-call "foo" + fmt.Println() // @callees callees-qualified-call "Println" + m := new(method) + m.f() // @callees callees-static-method-call "f" } type myint int @@ -78,6 +91,11 @@ func (myint) f() { // @callers callers-not-a-wrapper "^" } +type method int + +func (method) f() { +} + var dynamic = func() {} func deadcode() { diff --git a/oracle/testdata/src/calls/main.golden b/oracle/testdata/src/calls/main.golden index c06f0e85..f6bfb0ea 100644 --- a/oracle/testdata/src/calls/main.golden +++ b/oracle/testdata/src/calls/main.golden @@ -84,6 +84,18 @@ dynamic method call on nil value this dynamic method call dispatches to: (main.myint).f +-------- @callees callees-static-call -------- +this static function call dispatches to: + main.foo + +-------- @callees callees-qualified-call -------- +this static function call dispatches to: + fmt.Println + +-------- @callees callees-static-method-call -------- +this static function call dispatches to: + (main.method).f + -------- @callers callers-not-a-wrapper -------- (main.myint).f is called from these 1 sites: dynamic method call from main.main