diff --git a/go/types/api.go b/go/types/api.go index 7b4b8c99..341c8e42 100644 --- a/go/types/api.go +++ b/go/types/api.go @@ -143,6 +143,5 @@ func IsAssignableTo(V, T Type) bool { // BUG(gri): Use of labels is only partially checked. // BUG(gri): Unused variables and imports are not reported. // BUG(gri): Interface vs non-interface comparisons are not correctly implemented. -// BUG(gri): Switch statements don't check correct use of 'fallthrough'. // BUG(gri): Switch statements don't check duplicate cases for all types for which it is required. // BUG(gri): Some built-ins may not be callable if in statement-context. diff --git a/go/types/resolver.go b/go/types/resolver.go index 04663c45..b4d547f1 100644 --- a/go/types/resolver.go +++ b/go/types/resolver.go @@ -358,7 +358,7 @@ func (check *checker) resolveFiles(files []*ast.File) { } check.topScope = f.sig.scope // open the function scope check.funcSig = f.sig - check.stmtList(f.body.List) + check.stmtList(f.body.List, false) if f.sig.results.Len() > 0 && !check.isTerminating(f.body, "") { check.errorf(f.body.Rbrace, "missing return") } diff --git a/go/types/stdlib_test.go b/go/types/stdlib_test.go index faaf8160..f032e9b9 100644 --- a/go/types/stdlib_test.go +++ b/go/types/stdlib_test.go @@ -63,9 +63,6 @@ func TestStdtest(t *testing.T) { case "sizeof.go", "switch.go": // TODO(gri) tone down duplicate checking in expression switches continue - case "switch4.go": - // TODO(gri) fix fallthrough checking - continue case "typeswitch2.go": // TODO(gri) implement duplicate checking in type switches continue diff --git a/go/types/stmt.go b/go/types/stmt.go index d65a2ef2..e0f70067 100644 --- a/go/types/stmt.go +++ b/go/types/stmt.go @@ -14,17 +14,17 @@ import ( func (check *checker) optionalStmt(s ast.Stmt) { if s != nil { scope := check.topScope - check.stmt(s) + check.stmt(s, false) assert(check.topScope == scope) } } -func (check *checker) stmtList(list []ast.Stmt) { - for _, s := range list { - scope := check.topScope - check.stmt(s) - assert(check.topScope == scope) +func (check *checker) stmtList(list []ast.Stmt, fallthroughOk bool) { + scope := check.topScope + for i, s := range list { + check.stmt(s, fallthroughOk && i+1 == len(list)) } + assert(check.topScope == scope) } func (check *checker) multipleDefaults(list []ast.Stmt) { @@ -66,7 +66,7 @@ func (check *checker) closeScope() { } // stmt typechecks statement s. -func (check *checker) stmt(s ast.Stmt) { +func (check *checker) stmt(s ast.Stmt, fallthroughOk bool) { // statements cannot use iota in general // (constant declarations set it explicitly) assert(check.iota == nil) @@ -86,7 +86,7 @@ func (check *checker) stmt(s ast.Stmt) { } label := s.Label check.declareObj(scope, label, NewLabel(label.Pos(), label.Name)) - check.stmt(s.Stmt) + check.stmt(s.Stmt, fallthroughOk) case *ast.ExprStmt: var x operand @@ -240,11 +240,28 @@ func (check *checker) stmt(s ast.Stmt) { } case *ast.BranchStmt: - // TODO(gri) implement this + switch s.Tok { + case token.BREAK: + // TODO(gri) implement checks + case token.CONTINUE: + // TODO(gri) implement checks + case token.GOTO: + // TODO(gri) implement checks + case token.FALLTHROUGH: + if s.Label != nil { + check.invalidAST(s.Label.Pos(), "fallthrough statement cannot have label") + // ok to continue + } + if !fallthroughOk { + check.errorf(s.Pos(), "fallthrough statement out of place") + } + default: + check.invalidAST(s.Pos(), "unknown branch statement (%s)", s.Tok) + } case *ast.BlockStmt: check.openScope(s) - check.stmtList(s.List) + check.stmtList(s.List, false) check.closeScope() case *ast.IfStmt: @@ -255,7 +272,7 @@ func (check *checker) stmt(s ast.Stmt) { if x.mode != invalid && !isBoolean(x.typ) { check.errorf(s.Cond.Pos(), "non-boolean condition in if statement") } - check.stmt(s.Body) + check.stmt(s.Body, false) check.optionalStmt(s.Else) check.closeScope() @@ -275,8 +292,8 @@ func (check *checker) stmt(s ast.Stmt) { check.multipleDefaults(s.Body.List) // TODO(gri) check also correct use of fallthrough seen := make(map[interface{}]token.Pos) - for _, s := range s.Body.List { - clause, _ := s.(*ast.CaseClause) + for i, c := range s.Body.List { + clause, _ := c.(*ast.CaseClause) if clause == nil { continue // error reported before } @@ -322,7 +339,7 @@ func (check *checker) stmt(s ast.Stmt) { } } check.openScope(clause) - check.stmtList(clause.Body) + check.stmtList(clause.Body, i+1 < len(s.Body.List)) check.closeScope() } check.closeScope() @@ -420,7 +437,7 @@ func (check *checker) stmt(s ast.Stmt) { check.declareObj(check.topScope, nil, obj) check.recordImplicit(clause, obj) } - check.stmtList(clause.Body) + check.stmtList(clause.Body, false) check.closeScope() } @@ -433,7 +450,7 @@ func (check *checker) stmt(s ast.Stmt) { } check.openScope(clause) check.optionalStmt(clause.Comm) // TODO(gri) check correctness of c.Comm (must be Send/RecvStmt) - check.stmtList(clause.Body) + check.stmtList(clause.Body, false) check.closeScope() } @@ -448,7 +465,7 @@ func (check *checker) stmt(s ast.Stmt) { } } check.optionalStmt(s.Post) - check.stmt(s.Body) + check.stmt(s.Body, false) check.closeScope() case *ast.RangeStmt: @@ -463,7 +480,7 @@ func (check *checker) stmt(s ast.Stmt) { // if we don't have a declaration, we can still check the loop's body // (otherwise we can't because we are missing the declared variables) if !decl { - check.stmt(s.Body) + check.stmt(s.Body, false) } return } @@ -506,7 +523,7 @@ func (check *checker) stmt(s ast.Stmt) { check.errorf(x.pos(), "cannot range over %s", &x) // if we don't have a declaration, we can still check the loop's body if !decl { - check.stmt(s.Body) + check.stmt(s.Body, false) } return } @@ -568,7 +585,7 @@ func (check *checker) stmt(s ast.Stmt) { } } - check.stmt(s.Body) + check.stmt(s.Body, false) default: check.errorf(s.Pos(), "invalid statement") diff --git a/go/types/testdata/stmt0.src b/go/types/testdata/stmt0.src index f00e98c6..2fb29e9b 100644 --- a/go/types/testdata/stmt0.src +++ b/go/types/testdata/stmt0.src @@ -127,9 +127,12 @@ func defers() { defer len(c) // TODO(gri) this should not be legal } -func switches() { +func switches0() { var x int + switch x { + } + switch x { default: default /* ERROR "multiple defaults" */ : @@ -162,6 +165,53 @@ func switches() { } } +func switches1() { + fallthrough /* ERROR "fallthrough statement out of place" */ + + var x int + switch x { + case 0: + fallthrough /* ERROR "fallthrough statement out of place" */ + break + case 1: + fallthrough + case 2: + default: + fallthrough + case 3: + fallthrough /* ERROR "fallthrough statement out of place" */ + } + + var y interface{} + switch y.(type) { + case int: + fallthrough /* ERROR "fallthrough statement out of place" */ + default: + } + + switch x { + case 0: + if x == 0 { + fallthrough /* ERROR "fallthrough statement out of place" */ + } + } + + switch x { + case 0: + L1: fallthrough + case 1: + L2: L3: L4: fallthrough + default: + } + + switch x { + case 0: + L5: fallthrough + default: + L6: L7: L8: fallthrough /* ERROR "fallthrough statement out of place" */ + } +} + type I interface { m() }