diff --git a/go/types/decl.go b/go/types/decl.go index 5d170862..c5d89905 100644 --- a/go/types/decl.go +++ b/go/types/decl.go @@ -151,6 +151,26 @@ func (check *checker) varDecl(obj *Var, lhs []*Var, typ, init ast.Expr) { check.initVars(lhs, []ast.Expr{init}, token.NoPos) } +// underlying returns the underlying type of typ; possibly by following +// forward chains of named types. Such chains only exist while names types +// are incomplete. +func underlying(typ Type) Type { + for { + n, _ := typ.(*Named) + if n == nil { + break + } + typ = n.underlying + } + return typ +} + +func (n *Named) setUnderlying(typ Type) { + if n != nil { + n.underlying = typ + } +} + func (check *checker) typeDecl(obj *TypeName, typ ast.Expr, def *Named, cycleOk bool) { assert(obj.Type() == nil) @@ -158,17 +178,14 @@ func (check *checker) typeDecl(obj *TypeName, typ ast.Expr, def *Named, cycleOk assert(check.iota == nil) named := &Named{obj: obj} + def.setUnderlying(named) obj.typ = named // make sure recursive type declarations terminate - // If this type (named) defines the type of another (def) type declaration, - // set def's underlying type to this type so that we can resolve the true - // underlying of def later. - if def != nil { - def.underlying = named - } + // determine underlying type of named + check.typ(typ, named, cycleOk) - // Typecheck typ - it may be a named type that is not yet complete. - // For instance, consider: + // The underlying type of named may be itself a named type that is + // incomplete: // // type ( // A B @@ -176,21 +193,11 @@ func (check *checker) typeDecl(obj *TypeName, typ ast.Expr, def *Named, cycleOk // C A // ) // - // When we declare object C, typ is the identifier A which is incomplete. - u := check.typ(typ, named, cycleOk) - - // Determine the unnamed underlying type. - // In the above example, the underlying type of A was (temporarily) set - // to B whose underlying type was set to *C. Such "forward chains" always - // end in an unnamed type (cycles are terminated with an invalid type). - for { - n, _ := u.(*Named) - if n == nil { - break - } - u = n.underlying - } - named.underlying = u + // The type of C is the (named) type of A which is incomplete, + // and which has as its underlying type the named type B. + // Determine the (final, unnamed) underlying type by resolving + // any forward chain (they always end in an unnamed type). + named.underlying = underlying(named.underlying) // the underlying type has been determined named.complete = true @@ -257,14 +264,14 @@ func (check *checker) funcDecl(obj *Func, info *declInfo) { // func declarations cannot use iota assert(check.iota == nil) - obj.typ = Typ[Invalid] // guard against cycles + sig := new(Signature) + obj.typ = sig // guard against cycles fdecl := info.fdecl - sig := check.funcType(fdecl.Recv, fdecl.Type, nil) + check.funcType(sig, fdecl.Recv, fdecl.Type) if sig.recv == nil && obj.name == "init" && (sig.params.Len() > 0 || sig.results.Len() > 0) { check.errorf(fdecl.Pos(), "func init must have no arguments and no return values") // ok to continue } - obj.typ = sig // function body must be type-checked after global declarations // (functions implemented elsewhere have no body) diff --git a/go/types/typexpr.go b/go/types/typexpr.go index 0ca17a88..c8cbd4ed 100644 --- a/go/types/typexpr.go +++ b/go/types/typexpr.go @@ -36,7 +36,7 @@ func (check *checker) ident(x *operand, e *ast.Ident, def *Named, cycleOk bool) typ := obj.Type() if typ == nil { - // object not yet declared + // object type not yet determined if check.objMap == nil { check.dump("%s: %s should have been declared (we are inside a function)", e.Pos(), e) unreachable() @@ -82,9 +82,6 @@ func (check *checker) ident(x *operand, e *ast.Ident, def *Named, cycleOk bool) // maintain x.mode == typexpr despite error typ = Typ[Invalid] } - if def != nil { - def.underlying = typ - } case *Var: obj.used = true @@ -141,12 +138,7 @@ func (check *checker) typ(e ast.Expr, def *Named, cycleOk bool) (T Type) { } // funcType type-checks a function or method type and returns its signature. -func (check *checker) funcType(recv *ast.FieldList, ftyp *ast.FuncType, def *Named) *Signature { - sig := new(Signature) - if def != nil { - def.underlying = sig - } - +func (check *checker) funcType(sig *Signature, recv *ast.FieldList, ftyp *ast.FuncType) *Signature { scope := NewScope(check.topScope) check.recordScope(ftyp, scope) @@ -202,7 +194,7 @@ func (check *checker) funcType(recv *ast.FieldList, ftyp *ast.FuncType, def *Nam return sig } -// typInternal contains the core of type checking of types. +// typInternal drives type checking of types. // Must only be called by typ. // func (check *checker) typInternal(e ast.Expr, def *Named, cycleOk bool) Type { @@ -216,7 +208,9 @@ func (check *checker) typInternal(e ast.Expr, def *Named, cycleOk bool) Type { switch x.mode { case typexpr: - return x.typ + typ := x.typ + def.setUnderlying(typ) + return typ case invalid: // ignore - error reported before case novalue: @@ -231,7 +225,9 @@ func (check *checker) typInternal(e ast.Expr, def *Named, cycleOk bool) Type { switch x.mode { case typexpr: - return x.typ + typ := x.typ + def.setUnderlying(typ) + return typ case invalid: // ignore - error reported before case novalue: @@ -246,53 +242,45 @@ func (check *checker) typInternal(e ast.Expr, def *Named, cycleOk bool) Type { case *ast.ArrayType: if e.Len != nil { typ := new(Array) - if def != nil { - def.underlying = typ - } - + def.setUnderlying(typ) typ.len = check.arrayLength(e.Len) typ.elem = check.typ(e.Elt, nil, cycleOk) return typ } else { typ := new(Slice) - if def != nil { - def.underlying = typ - } - + def.setUnderlying(typ) typ.elem = check.typ(e.Elt, nil, true) return typ } case *ast.StructType: typ := new(Struct) - if def != nil { - def.underlying = typ - } - - typ.fields, typ.tags = check.collectFields(e.Fields, cycleOk) + def.setUnderlying(typ) + check.structType(typ, e, cycleOk) return typ case *ast.StarExpr: typ := new(Pointer) - if def != nil { - def.underlying = typ - } - + def.setUnderlying(typ) typ.base = check.typ(e.X, nil, true) return typ case *ast.FuncType: - return check.funcType(nil, e, def) + typ := new(Signature) + def.setUnderlying(typ) + check.funcType(typ, nil, e) + return typ case *ast.InterfaceType: - return check.interfaceType(e, def, cycleOk) + typ := new(Interface) + def.setUnderlying(typ) + check.interfaceType(typ, e, def, cycleOk) + return typ case *ast.MapType: typ := new(Map) - if def != nil { - def.underlying = typ - } + def.setUnderlying(typ) typ.key = check.typ(e.Key, nil, true) typ.elem = check.typ(e.Value, nil, true) @@ -313,9 +301,7 @@ func (check *checker) typInternal(e ast.Expr, def *Named, cycleOk bool) Type { case *ast.ChanType: typ := new(Chan) - if def != nil { - def.underlying = typ - } + def.setUnderlying(typ) dir := SendRecv switch e.Dir { @@ -329,6 +315,7 @@ func (check *checker) typInternal(e ast.Expr, def *Named, cycleOk bool) Type { check.invalidAST(e.Pos(), "unknown channel direction %d", e.Dir) // ok to continue } + typ.dir = dir typ.elem = check.typ(e.Value, nil, true) return typ @@ -337,7 +324,9 @@ func (check *checker) typInternal(e ast.Expr, def *Named, cycleOk bool) Type { check.errorf(e.Pos(), "%s is not a type", e) } - return Typ[Invalid] + typ := Typ[Invalid] + def.setUnderlying(typ) + return typ } // typeOrNil type-checks the type expression (or nil value) e @@ -438,15 +427,10 @@ func (check *checker) declareInSet(oset *objset, pos token.Pos, obj Object) bool return true } -func (check *checker) interfaceType(ityp *ast.InterfaceType, def *Named, cycleOk bool) *Interface { - iface := new(Interface) - if def != nil { - def.underlying = iface - } - +func (check *checker) interfaceType(iface *Interface, ityp *ast.InterfaceType, def *Named, cycleOk bool) { // empty interface: common case if ityp.Methods == nil { - return iface + return } // The parser ensures that field tags are nil and we don't @@ -501,31 +485,21 @@ func (check *checker) interfaceType(ityp *ast.InterfaceType, def *Named, cycleOk for _, e := range embedded { pos := e.Pos() typ := check.typ(e, nil, cycleOk) - if typ == Typ[Invalid] { - continue - } named, _ := typ.(*Named) if named == nil { - check.invalidAST(pos, "%s is not named type", typ) + if typ != Typ[Invalid] { + check.invalidAST(pos, "%s is not named type", typ) + } continue } // determine underlying (possibly incomplete) type // by following its forward chain - // TODO(gri) should this be part of Underlying()? - u := named.underlying - for { - n, _ := u.(*Named) - if n == nil { - break - } - u = n.underlying - } - if u == Typ[Invalid] { - continue - } + u := underlying(named) embed, _ := u.(*Interface) if embed == nil { - check.errorf(pos, "%s is not an interface", named) + if u != Typ[Invalid] { + check.errorf(pos, "%s is not an interface", named) + } continue } iface.embeddeds = append(iface.embeddeds, named) @@ -553,12 +527,11 @@ func (check *checker) interfaceType(ityp *ast.InterfaceType, def *Named, cycleOk for i, m := range iface.methods { expr := signatures[i] typ := check.typ(expr, nil, true) - if typ == Typ[Invalid] { - continue // keep method with empty method signature - } sig, _ := typ.(*Signature) if sig == nil { - check.invalidAST(expr.Pos(), "%s is not a method signature", typ) + if typ != Typ[Invalid] { + check.invalidAST(expr.Pos(), "%s is not a method signature", typ) + } continue // keep method with empty method signature } sig.recv = NewVar(m.pos, check.pkg, "", recv) @@ -576,8 +549,6 @@ func (check *checker) interfaceType(ityp *ast.InterfaceType, def *Named, cycleOk sort.Sort(byUniqueTypeName(iface.embeddeds)) sort.Sort(byUniqueMethodName(iface.allMethods)) - - return iface } // byUniqueTypeName named type lists can be sorted by their unique type names. @@ -606,15 +577,22 @@ func (check *checker) tag(t *ast.BasicLit) string { return "" } -func (check *checker) collectFields(list *ast.FieldList, cycleOk bool) (fields []*Var, tags []string) { +func (check *checker) structType(styp *Struct, e *ast.StructType, cycleOk bool) { + list := e.Fields if list == nil { return } + // struct fields and tags + var fields []*Var + var tags []string + + // for double-declaration checks var fset objset - var typ Type // current field typ - var tag string // current field tag + // current field typ and tag + var typ Type + var tag string add := func(field *ast.Field, ident *ast.Ident, name string, anonymous bool, pos token.Pos) { if tag != "" && tags == nil { tags = make([]string, len(fields)) @@ -662,7 +640,7 @@ func (check *checker) collectFields(list *ast.FieldList, cycleOk bool) (fields [ // spec: "An embedded type must be specified as a type name // T or as a pointer to a non-interface type name *T, and T // itself may not be a pointer type." - switch u := t.Underlying().(type) { + switch u := t.underlying.(type) { case *Basic: // unsafe.Pointer is treated like a regular pointer if u.kind == UnsafePointer { @@ -686,5 +664,6 @@ func (check *checker) collectFields(list *ast.FieldList, cycleOk bool) (fields [ } } - return + styp.fields = fields + styp.tags = tags }