diff --git a/go/pointer/analysis.go b/go/pointer/analysis.go index 11541823..0abb04dd 100644 --- a/go/pointer/analysis.go +++ b/go/pointer/analysis.go @@ -177,6 +177,11 @@ func (a *analysis) warnf(pos token.Pos, format string, args ...interface{}) { // computeTrackBits sets a.track to the necessary 'track' bits for the pointer queries. func (a *analysis) computeTrackBits() { + if len(a.config.extendedQueries) != 0 { + // TODO(dh): only track the types necessary for the query. + a.track = trackAll + return + } var queryTypes []types.Type for v := range a.config.Queries { queryTypes = append(queryTypes, v.Type()) diff --git a/go/pointer/api.go b/go/pointer/api.go index 8f9ae0ab..2b7e79e6 100644 --- a/go/pointer/api.go +++ b/go/pointer/api.go @@ -16,7 +16,9 @@ import ( "golang.org/x/tools/go/types/typeutil" ) -// A Config formulates a pointer analysis problem for Analyze(). +// A Config formulates a pointer analysis problem for Analyze. It is +// only usable for a single invocation of Analyze and must not be +// reused. type Config struct { // Mains contains the set of 'main' packages to analyze // Clients must provide the analysis with at least one @@ -61,6 +63,7 @@ type Config struct { // Queries map[ssa.Value]struct{} IndirectQueries map[ssa.Value]struct{} + extendedQueries map[ssa.Value][]*extendedQuery // If Log is non-nil, log messages are written to it. // Logging is extremely verbose. @@ -80,9 +83,6 @@ const ( // AddQuery adds v to Config.Queries. // Precondition: CanPoint(v.Type()). -// TODO(adonovan): consider returning a new Pointer for this query, -// which will be initialized during analysis. That avoids the needs -// for the corresponding ssa.Value-keyed maps in Config and Result. 
func (c *Config) AddQuery(v ssa.Value) { if !CanPoint(v.Type()) { panic(fmt.Sprintf("%s is not a pointer-like value: %s", v, v.Type())) @@ -105,6 +105,46 @@ func (c *Config) AddIndirectQuery(v ssa.Value) { c.IndirectQueries[v] = struct{}{} } +// AddExtendedQuery adds an extended, AST-based query on v to the +// analysis. The query, which must be a single Go expression, allows +// destructuring the value. +// +// The query must operate on a variable named 'x', which represents +// the value, and result in a pointer-like object. Only a subset of +// Go expressions are permitted in queries, namely channel receives, +// pointer dereferences, field selectors, array/slice/map/tuple +// indexing and grouping with parentheses. The specific indices when +// indexing arrays, slices and maps have no significance. Indices used +// on tuples must be numeric and within bounds. +// +// All field selectors must be explicit, even ones usually elided +// due to promotion of embedded fields. +// +// The query 'x' is identical to using AddQuery. The query '*x' is +// identical to using AddIndirectQuery. +// +// On success, AddExtendedQuery returns a Pointer to the queried +// value. This Pointer will be initialized during analysis. Using it +// before analysis has finished has undefined behavior. +// +// Example: +// // given v, which represents a function call to 'fn() (int, []*T)', and +// // 'type T struct { F *int }', the following query will access the field F. 
+// c.AddExtendedQuery(v, "x[1][0].F") +func (c *Config) AddExtendedQuery(v ssa.Value, query string) (*Pointer, error) { + ops, _, err := parseExtendedQuery(v.Type().Underlying(), query) + if err != nil { + return nil, fmt.Errorf("invalid query %q: %s", query, err) + } + if c.extendedQueries == nil { + c.extendedQueries = make(map[ssa.Value][]*extendedQuery) + } + + ptr := &Pointer{} + c.extendedQueries[v] = append(c.extendedQueries[v], &extendedQuery{ops: ops, ptr: ptr}) + return ptr, nil +} + func (c *Config) prog() *ssa.Program { for _, main := range c.Mains { return main.Prog diff --git a/go/pointer/gen.go b/go/pointer/gen.go index d62a8cac..a111175c 100644 --- a/go/pointer/gen.go +++ b/go/pointer/gen.go @@ -108,6 +108,16 @@ func (a *analysis) setValueNode(v ssa.Value, id nodeid, cgn *cgnode) { } a.genLoad(cgn, ptr.n, v, 0, a.sizeof(t)) } + + for _, query := range a.config.extendedQueries[v] { + t, nid := a.evalExtendedQuery(v.Type().Underlying(), id, query.ops) + + if query.ptr.a == nil { + query.ptr.a = a + query.ptr.n = a.addNodes(t, "query.extended") + } + a.copy(query.ptr.n, nid, a.sizeof(t)) + } } // endObject marks the end of a sequence of calls to addNodes denoting diff --git a/go/pointer/opt.go b/go/pointer/opt.go index 2620cc0d..81f80aa4 100644 --- a/go/pointer/opt.go +++ b/go/pointer/opt.go @@ -102,6 +102,13 @@ func (a *analysis) renumber() { ptr.n = renumbering[ptr.n] a.result.IndirectQueries[v] = ptr } + for _, queries := range a.config.extendedQueries { + for _, query := range queries { + if query.ptr != nil { + query.ptr.n = renumbering[query.ptr.n] + } + } + } // Renumber nodeids in global objects. 
for v, id := range a.globalobj { diff --git a/go/pointer/pointer_test.go b/go/pointer/pointer_test.go index 8918406b..573de5bf 100644 --- a/go/pointer/pointer_test.go +++ b/go/pointer/pointer_test.go @@ -42,6 +42,7 @@ var inputs = []string{ "testdata/chanreflect.go", "testdata/context.go", "testdata/conv.go", + "testdata/extended.go", "testdata/finalizer.go", "testdata/flow.go", "testdata/fmtexcerpt.go", @@ -120,11 +121,13 @@ var inputs = []string{ // (NB, anon functions still include line numbers.) // type expectation struct { - kind string // "pointsto" | "types" | "calls" | "warning" + kind string // "pointsto" | "pointstoquery" | "types" | "calls" | "warning" filename string linenum int // source line number, 1-based args []string - types []types.Type // for types + query string // extended query + extended *pointer.Pointer // extended query pointer + types []types.Type // for types } func (e *expectation) String() string { @@ -138,7 +141,7 @@ func (e *expectation) errorf(format string, args ...interface{}) { } func (e *expectation) needsProbe() bool { - return e.kind == "pointsto" || e.kind == "types" + return e.kind == "pointsto" || e.kind == "pointstoquery" || e.kind == "types" } // Find probe (call to print(x)) of same source file/line as expectation. @@ -239,6 +242,10 @@ func doOneInput(input, filename string) bool { case "pointsto": e.args = split(rest, "|") + case "pointstoquery": + args := strings.SplitN(rest, " ", 2) + e.query = args[0] + e.args = split(args[1], "|") case "types": for _, typstr := range split(rest, "|") { var t types.Type = types.Typ[types.Invalid] // means "..." 
// indexValue returns the value of an integer literal used as an
// index in an extended query.
//
// expr must be an *ast.BasicLit of kind token.INT. The literal text
// is interpreted with Go's integer literal syntax, so base prefixes
// and digit separators are honoured: "010" is 8, "0x10" is 16 and
// "1_000" is 1000.
//
// An error is returned if expr is not a basic literal, if the
// literal is not an integer, or if its value does not fit in an int.
// The returned int is 0 whenever the error is non-nil.
func indexValue(expr ast.Expr) (int, error) {
	lit, ok := expr.(*ast.BasicLit)
	if !ok {
		return 0, fmt.Errorf("non-integer index (%T)", expr)
	}
	if lit.Kind != token.INT {
		return 0, fmt.Errorf("non-integer index %s", lit.Value)
	}
	// Base 0 makes ParseInt derive the base from the literal's prefix
	// and accept underscores, matching how the Go parser scanned it.
	// strconv.Atoi would read "010" as decimal 10 and reject "0x10".
	// Bit size 0 restricts the result to the range of int.
	n, err := strconv.ParseInt(lit.Value, 0, 0)
	if err != nil {
		return 0, err
	}
	return int(n), nil
}
+func parseExtendedQuery(typ types.Type, query string) ([]interface{}, types.Type, error) { + expr, err := parser.ParseExpr(query) + if err != nil { + return nil, nil, err + } + ops, typ, err := destructuringOps(typ, expr) + if err != nil { + return nil, nil, err + } + if len(ops) == 0 { + return nil, nil, errors.New("invalid query: must not be empty") + } + if ops[0] != "x" { + return nil, nil, fmt.Errorf("invalid query: query operand must be named x") + } + if !CanPoint(typ) { + return nil, nil, fmt.Errorf("query does not describe a pointer-like value: %s", typ) + } + return ops, typ, nil +} + +// destructuringOps parses a Go expression consisting only of an +// identifier "x", field selections, indexing, channel receives, load +// operations and parens---for example: "<-(*x[i])[key]"--- and +// returns the sequence of destructuring operations on x. +func destructuringOps(typ types.Type, expr ast.Expr) ([]interface{}, types.Type, error) { + switch expr := expr.(type) { + case *ast.SelectorExpr: + out, typ, err := destructuringOps(typ, expr.X) + if err != nil { + return nil, nil, err + } + + var structT *types.Struct + switch typ := typ.(type) { + case *types.Pointer: + var ok bool + structT, ok = typ.Elem().Underlying().(*types.Struct) + if !ok { + return nil, nil, fmt.Errorf("cannot access field %s of pointer to type %s", expr.Sel.Name, typ.Elem()) + } + + out = append(out, "load") + case *types.Struct: + structT = typ + default: + return nil, nil, fmt.Errorf("cannot access field %s of type %s", expr.Sel.Name, typ) + } + + for i := 0; i < structT.NumFields(); i++ { + field := structT.Field(i) + if field.Name() == expr.Sel.Name { + out = append(out, "field", i) + return out, field.Type().Underlying(), nil + } + } + // TODO(dh): supporting embedding would need something like + // types.LookupFieldOrMethod, but without taking package + // boundaries into account, because we may want to access + // unexported fields. 
If we were only interested in one level + // of unexported name, we could determine the appropriate + // package and run LookupFieldOrMethod with that. However, a + // single query may want to cross multiple package boundaries, + // and at this point it's not really worth the complexity. + return nil, nil, fmt.Errorf("no field %s in %s (embedded fields must be resolved manually)", expr.Sel.Name, structT) + case *ast.Ident: + return []interface{}{expr.Name}, typ, nil + case *ast.BasicLit: + return []interface{}{expr.Value}, nil, nil + case *ast.IndexExpr: + out, typ, err := destructuringOps(typ, expr.X) + if err != nil { + return nil, nil, err + } + switch typ := typ.(type) { + case *types.Array: + out = append(out, "arrayelem") + return out, typ.Elem().Underlying(), nil + case *types.Slice: + out = append(out, "sliceelem") + return out, typ.Elem().Underlying(), nil + case *types.Map: + out = append(out, "mapelem") + return out, typ.Elem().Underlying(), nil + case *types.Tuple: + out = append(out, "index") + idx, err := indexValue(expr.Index) + if err != nil { + return nil, nil, err + } + out = append(out, idx) + if idx >= typ.Len() || idx < 0 { + return nil, nil, fmt.Errorf("tuple index %d out of bounds", idx) + } + return out, typ.At(idx).Type().Underlying(), nil + default: + return nil, nil, fmt.Errorf("cannot index type %s", typ) + } + + case *ast.UnaryExpr: + if expr.Op != token.ARROW { + return nil, nil, fmt.Errorf("unsupported unary operator %s", expr.Op) + } + out, typ, err := destructuringOps(typ, expr.X) + if err != nil { + return nil, nil, err + } + ch, ok := typ.(*types.Chan) + if !ok { + return nil, nil, fmt.Errorf("cannot receive from value of type %s", typ) + } + out = append(out, "recv") + return out, ch.Elem().Underlying(), err + case *ast.ParenExpr: + return destructuringOps(typ, expr.X) + case *ast.StarExpr: + out, typ, err := destructuringOps(typ, expr.X) + if err != nil { + return nil, nil, err + } + ptr, ok := typ.(*types.Pointer) + if !ok { + 
return nil, nil, fmt.Errorf("cannot dereference type %s", typ) + } + out = append(out, "load") + return out, ptr.Elem().Underlying(), err + default: + return nil, nil, fmt.Errorf("unsupported expression %T", expr) + } +} + +func (a *analysis) evalExtendedQuery(t types.Type, id nodeid, ops []interface{}) (types.Type, nodeid) { + pid := id + // TODO(dh): we're allocating intermediary nodes each time + // evalExtendedQuery is called. We should probably only generate + // them once per (v, ops) pair. + for i := 1; i < len(ops); i++ { + var nid nodeid + switch ops[i] { + case "recv": + t = t.(*types.Chan).Elem().Underlying() + nid = a.addNodes(t, "query.extended") + a.load(nid, pid, 0, a.sizeof(t)) + case "field": + i++ // fetch field index + tt := t.(*types.Struct) + idx := ops[i].(int) + offset := a.offsetOf(t, idx) + t = tt.Field(idx).Type().Underlying() + nid = a.addNodes(t, "query.extended") + a.copy(nid, pid+nodeid(offset), a.sizeof(t)) + case "arrayelem": + t = t.(*types.Array).Elem().Underlying() + nid = a.addNodes(t, "query.extended") + a.copy(nid, 1+pid, a.sizeof(t)) + case "sliceelem": + t = t.(*types.Slice).Elem().Underlying() + nid = a.addNodes(t, "query.extended") + a.load(nid, pid, 1, a.sizeof(t)) + case "mapelem": + tt := t.(*types.Map) + t = tt.Elem() + ksize := a.sizeof(tt.Key()) + vsize := a.sizeof(tt.Elem()) + nid = a.addNodes(t, "query.extended") + a.load(nid, pid, ksize, vsize) + case "index": + i++ // fetch index + tt := t.(*types.Tuple) + idx := ops[i].(int) + t = tt.At(idx).Type().Underlying() + nid = a.addNodes(t, "query.extended") + a.copy(nid, pid+nodeid(idx), a.sizeof(t)) + case "load": + t = t.(*types.Pointer).Elem().Underlying() + nid = a.addNodes(t, "query.extended") + a.load(nid, pid, 0, a.sizeof(t)) + default: + // shouldn't happen + panic(fmt.Sprintf("unknown op %q", ops[i])) + } + pid = nid + } + + return t, pid +} diff --git a/go/pointer/query_test.go b/go/pointer/query_test.go new file mode 100644 index 00000000..099cc542 --- 
/dev/null +++ b/go/pointer/query_test.go @@ -0,0 +1,68 @@ +package pointer + +import ( + "go/ast" + "go/parser" + "go/token" + "go/types" + "reflect" + "testing" +) + +func TestParseExtendedQuery(t *testing.T) { + const myprog = ` +package pkg +var V1 *int +var V2 **int +var V3 []*int +var V4 chan []*int +var V5 struct {F1, F2 chan *int} +var V6 [1]chan *int +var V7 int +` + tests := []struct { + in string + out []interface{} + v string + valid bool + }{ + {`x`, []interface{}{"x"}, "V1", true}, + {`*x`, []interface{}{"x", "load"}, "V2", true}, + {`x[0]`, []interface{}{"x", "sliceelem"}, "V3", true}, + {`<-x`, []interface{}{"x", "recv"}, "V4", true}, + {`(<-x)[0]`, []interface{}{"x", "recv", "sliceelem"}, "V4", true}, + {`<-x.F2`, []interface{}{"x", "field", 1, "recv"}, "V5", true}, + {`<-x[0]`, []interface{}{"x", "arrayelem", "recv"}, "V6", true}, + {`x`, nil, "V7", false}, + {`y`, nil, "V1", false}, + {`x; x`, nil, "V1", false}, + {`x()`, nil, "V1", false}, + {`close(x)`, nil, "V1", false}, + } + + fset := token.NewFileSet() + f, err := parser.ParseFile(fset, "file.go", myprog, 0) + if err != nil { + t.Fatal(err) + } + cfg := &types.Config{} + pkg, err := cfg.Check("main", fset, []*ast.File{f}, nil) + if err != nil { + t.Fatal(err) + } + + for _, test := range tests { + typ := pkg.Scope().Lookup(test.v).Type().Underlying() + ops, _, err := parseExtendedQuery(typ, test.in) + if test.valid && err != nil { + t.Errorf("parseExtendedQuery(%q) = %s, expected no error", test.in, err) + } + if !test.valid && err == nil { + t.Errorf("parseExtendedQuery(%q) succeeded, expected error", test.in) + } + + if !reflect.DeepEqual(ops, test.out) { + t.Errorf("parseExtendedQuery(%q) = %#v, want %#v", test.in, ops, test.out) + } + } +} diff --git a/go/pointer/testdata/extended.go b/go/pointer/testdata/extended.go new file mode 100644 index 00000000..b3dd2030 --- /dev/null +++ b/go/pointer/testdata/extended.go @@ -0,0 +1,21 @@ +// +build ignore + +package main + +var a int + +type t 
struct { + a *map[string]chan *int +} + +func fn() []t { + m := make(map[string]chan *int) + m[""] = make(chan *int, 1) + m[""] <- &a + return []t{t{a: &m}} +} + +func main() { + x := fn() + print(x) // @pointstoquery <-(*x[i].a)[key] main.a +} diff --git a/go/pointer/util.go b/go/pointer/util.go index 2f184788..683fdddd 100644 --- a/go/pointer/util.go +++ b/go/pointer/util.go @@ -26,7 +26,6 @@ func CanPoint(T types.Type) bool { return true // treat reflect.Value like interface{} } return CanPoint(T.Underlying()) - case *types.Pointer, *types.Interface, *types.Map, *types.Chan, *types.Signature, *types.Slice: return true } @@ -171,7 +170,7 @@ func (a *analysis) flatten(t types.Type) []*fieldInfo { } default: - panic(t) + panic(fmt.Sprintf("cannot flatten unsupported type %T", t)) } a.flattenMemo[t] = fl