diff --git a/pointer/analysis.go b/pointer/analysis.go index db04ef86..e3677c69 100644 --- a/pointer/analysis.go +++ b/pointer/analysis.go @@ -94,7 +94,8 @@ type node struct { // - *loadConstraint y=*x // - *offsetAddrConstraint y=&x.f or y=&x[0] // - *storeConstraint *x=z - // - *typeAssertConstraint y=x.(T) + // - *typeFilterConstraint y=x.(I) + // - *untagConstraint y=x.(C) // - *invokeConstraint y=x.f(params...) complex constraintset } @@ -150,14 +151,32 @@ type offsetAddrConstraint struct { src nodeid // (ptr) } -// dst = src.(typ) +// dst = src.(typ) where typ is an interface // A complex constraint attached to src (the interface). -type typeAssertConstraint struct { - typ types.Type +// No representation change: pts(dst) and pts(src) contains tagged objects. +type typeFilterConstraint struct { + typ types.Type // an interface type dst nodeid src nodeid // (ptr) } +// dst = src.(typ) where typ is a concrete type +// A complex constraint attached to src (the interface). +// +// If exact, only tagged objects identical to typ are untagged. +// If !exact, tagged objects assignable to typ are untagged too. +// The latter is needed for various reflect operators, e.g. Send. +// +// This entails a representation change: +// pts(src) contains tagged objects, +// pts(dst) contains their payloads. +type untagConstraint struct { + typ types.Type // a concrete type + dst nodeid + src nodeid // (ptr) + exact bool +} + // src.method(params...) // A complex constraint attached to iface. type invokeConstraint struct { diff --git a/pointer/gen.go b/pointer/gen.go index 8c82cf37..72e7e66c 100644 --- a/pointer/gen.go +++ b/pointer/gen.go @@ -351,9 +351,14 @@ func (a *analysis) offsetAddr(dst, src nodeid, offset uint32) { } } -// typeAssert creates a typeAssert constraint of the form dst = src.(T). -func (a *analysis) typeAssert(T types.Type, dst, src nodeid) { - a.addConstraint(&typeAssertConstraint{T, dst, src}) +// typeFilter creates a typeFilter constraint of the form dst = src.(I). +func (a *analysis) typeFilter(I types.Type, dst, src nodeid) { + a.addConstraint(&typeFilterConstraint{I, dst, src}) +} + +// untag creates an untag constraint of the form dst = src.(C). +func (a *analysis) untag(C types.Type, dst, src nodeid, exact bool) { + a.addConstraint(&untagConstraint{C, dst, src, exact}) } // addConstraint adds c to the constraint set. @@ -687,7 +692,7 @@ func (a *analysis) genInvokeReflectType(caller *cgnode, site *callsite, call *ss // Unpack receiver into rtype rtype := a.addOneNode(a.reflectRtypePtr, "rtype.recv", nil) recv := a.valueNode(call.Value) - a.typeAssert(a.reflectRtypePtr, rtype, recv) + a.untag(a.reflectRtypePtr, rtype, recv, true) // Look up the concrete method. meth := a.reflectRtypePtr.MethodSet().Lookup(call.Method.Pkg(), call.Method.Name()) @@ -801,7 +806,7 @@ func (a *analysis) objectNode(cgn *cgnode, v ssa.Value) nodeid { } if a.log != nil { - fmt.Fprintf(a.log, "\tglobalobj[%s] = n%d\n", a.nodes[obj].obj, obj) + fmt.Fprintf(a.log, "\tglobalobj[%s] = n%d\n", v, obj) } a.globalobj[v] = obj } @@ -870,7 +875,7 @@ func (a *analysis) objectNode(cgn *cgnode, v ssa.Value) nodeid { } if a.log != nil { - fmt.Fprintf(a.log, "\tlocalobj[%s] = n%d\n", a.nodes[obj].obj, obj) + fmt.Fprintf(a.log, "\tlocalobj[%s] = n%d\n", v.Name(), obj) } a.localobj[v] = obj } @@ -1005,7 +1010,12 @@ func (a *analysis) genInstr(cgn *cgnode, instr ssa.Instruction) { a.copy(a.valueNode(instr), a.valueNode(instr.X), 1) case *ssa.TypeAssert: - a.typeAssert(instr.AssertedType, a.valueNode(instr), a.valueNode(instr.X)) + T := instr.AssertedType + if _, ok := T.Underlying().(*types.Interface); ok { + a.typeFilter(T, a.valueNode(instr), a.valueNode(instr.X)) + } else { + a.untag(T, a.valueNode(instr), a.valueNode(instr.X), true) + } case *ssa.Slice: a.copy(a.valueNode(instr), a.valueNode(instr.X), 1) diff --git a/pointer/print.go b/pointer/print.go index 9184d3ab..a0b11e8a 100644 --- a/pointer/print.go +++ b/pointer/print.go @@ -26,8 +26,12 @@ func (c *offsetAddrConstraint) String() string { return fmt.Sprintf("offsetAddr n%d <- n%d.#%d", c.dst, c.src, c.offset) } -func (c *typeAssertConstraint) String() string { - return fmt.Sprintf("typeAssert n%d <- n%d.(%s)", c.dst, c.src, c.typ) +func (c *typeFilterConstraint) String() string { + return fmt.Sprintf("typeFilter n%d <- n%d.(%s)", c.dst, c.src, c.typ) +} + +func (c *untagConstraint) String() string { + return fmt.Sprintf("untag n%d <- n%d.(%s)", c.dst, c.src, c.typ) } func (c *invokeConstraint) String() string { diff --git a/pointer/reflect.go b/pointer/reflect.go index 51e49f5c..39eb717f 100644 --- a/pointer/reflect.go +++ b/pointer/reflect.go @@ -7,12 +7,6 @@ package pointer // For consistency, the names of all parameters match those of the // actual functions in the "reflect" package. // -// TODO(adonovan): fix: most of the reflect API permits implicit -// conversions due to assignability, e.g. m.MapIndex(k) is ok if T(k) -// is assignable to T(M).key. It's not yet clear how best to model -// that; perhaps a more lenient version of typeAssertConstraint is -// needed. -// // To avoid proliferation of equivalent labels, instrinsics should // memoize as much as possible, like TypeOf and Zero do for their // tagged objects. @@ -440,7 +434,7 @@ func (c *rVSendConstraint) solve(a *analysis, _ *node, delta nodeset) { // Extract x's payload to xtmp, then store to channel. tElem := tChan.Elem() xtmp := a.addNodes(tElem, "Send.xtmp") - a.typeAssert(tElem, xtmp, c.x) + a.untag(tElem, xtmp, c.x, false) a.store(ch, xtmp, 0, a.sizeof(tElem)) } } @@ -535,12 +529,12 @@ func (c *rVSetMapIndexConstraint) solve(a *analysis, _ *node, delta nodeset) { // Extract key's payload to keytmp, then store to map key. keytmp := a.addNodes(tMap.Key(), "SetMapIndex.keytmp") - a.typeAssert(tMap.Key(), keytmp, c.key) + a.untag(tMap.Key(), keytmp, c.key, false) a.store(m, keytmp, 0, keysize) // Extract val's payload to vtmp, then store to map value. valtmp := a.addNodes(tMap.Elem(), "SetMapIndex.valtmp") - a.typeAssert(tMap.Elem(), valtmp, c.val) + a.untag(tMap.Elem(), valtmp, c.val, false) a.store(m, valtmp, keysize, a.sizeof(tMap.Elem())) } } diff --git a/pointer/solve.go b/pointer/solve.go index 0cb8b5ba..66e9be0c 100644 --- a/pointer/solve.go +++ b/pointer/solve.go @@ -173,7 +173,10 @@ func (c *loadConstraint) ptr() nodeid { func (c *offsetAddrConstraint) ptr() nodeid { return c.src } -func (c *typeAssertConstraint) ptr() nodeid { +func (c *typeFilterConstraint) ptr() nodeid { + return c.src +} +func (c *untagConstraint) ptr() nodeid { return c.src } func (c *invokeConstraint) ptr() nodeid { @@ -251,9 +254,31 @@ func (c *offsetAddrConstraint) solve(a *analysis, n *node, delta nodeset) { } } -func (c *typeAssertConstraint) solve(a *analysis, n *node, delta nodeset) { - tIface, _ := c.typ.Underlying().(*types.Interface) +func (c *typeFilterConstraint) solve(a *analysis, n *node, delta nodeset) { + for ifaceObj := range delta { + tDyn, _, indirect := a.taggedValue(ifaceObj) + if tDyn == nil { + panic("not a tagged value") + } + if indirect { + // TODO(adonovan): we'll need to implement this + // when we start creating indirect tagged objects. + panic("indirect tagged object") + } + if types.IsAssignableTo(tDyn, c.typ) { + if a.addLabel(c.dst, ifaceObj) { + a.addWork(c.dst) + } + } + } +} + +func (c *untagConstraint) solve(a *analysis, n *node, delta nodeset) { + predicate := types.IsAssignableTo + if c.exact { + predicate = types.IsIdentical + } for ifaceObj := range delta { tDyn, v, indirect := a.taggedValue(ifaceObj) if tDyn == nil { @@ -265,23 +290,14 @@ func (c *typeAssertConstraint) solve(a *analysis, n *node, delta nodeset) { panic("indirect tagged object") } - if tIface != nil { - if types.IsAssignableTo(tDyn, tIface) { - if a.addLabel(c.dst, ifaceObj) { - a.addWork(c.dst) - } - } - } else { - if types.IsIdentical(tDyn, c.typ) { - // Copy entire payload to dst. - // - // TODO(adonovan): opt: if tConc is - // nonpointerlike we can skip this - // entire constraint, perhaps. We - // only care about pointers among the - // fields. - a.onlineCopyN(c.dst, v, a.sizeof(tDyn)) - } + if predicate(tDyn, c.typ) { + // Copy payload sans tag to dst. + // + // TODO(adonovan): opt: if tConc is + // nonpointerlike we can skip this entire + // constraint, perhaps. We only care about + // pointers among the fields. + a.onlineCopyN(c.dst, v, a.sizeof(tDyn)) } } } diff --git a/pointer/testdata/mapreflect.go b/pointer/testdata/mapreflect.go index 721d5e43..f1f14a24 100644 --- a/pointer/testdata/mapreflect.go +++ b/pointer/testdata/mapreflect.go @@ -64,6 +64,29 @@ func reflectSetMapIndex() { print(reflect.Zero(tmap.Elem()).Interface()) // @types *bool } +func reflectSetMapIndexAssignable() { + // SetMapIndex performs implicit assignability conversions. + type I *int + type J *int + + str := reflect.ValueOf("") + + // *int is assignable to I. + m1 := make(map[string]I) + reflect.ValueOf(m1).SetMapIndex(str, reflect.ValueOf(new(int))) // @line int + print(m1[""]) // @pointsto new@int:58 + + // I is assignable to I. + m2 := make(map[string]I) + reflect.ValueOf(m2).SetMapIndex(str, reflect.ValueOf(I(new(int)))) // @line I + print(m2[""]) // @pointsto new@I:60 + + // J is not assignable to I. + m3 := make(map[string]I) + reflect.ValueOf(m3).SetMapIndex(str, reflect.ValueOf(J(new(int)))) + print(m3[""]) // @pointsto +} + func reflectMakeMap() { t := reflect.TypeOf(map[*int]*bool(nil)) v := reflect.MakeMap(t) @@ -74,6 +97,7 @@ func reflectMakeMap() { func main() { reflectMapKeysIndex() reflectSetMapIndex() + reflectSetMapIndexAssignable() reflectMakeMap() // TODO(adonovan): reflect.MapOf(Type) }