diff --git a/cmd/cover/cover.go b/cmd/cover/cover.go index b5f2b41c..26be0834 100644 --- a/cmd/cover/cover.go +++ b/cmd/cover/cover.go @@ -146,6 +146,38 @@ func (f *File) Visit(node ast.Node) ast.Visitor { } } n.List = f.addCounters(n.Lbrace, n.Rbrace+1, n.List, true) // +1 to step past closing brace. + case *ast.IfStmt: + ast.Walk(f, n.Body) + if n.Else == nil { + return nil + } + // The elses are special, because if we have + // if x { + // } else if y { + // } + // we want to cover the "if y". To do this, we need a place to drop the counter, + // so we add a hidden block: + // if x { + // } else { + // if y { + // } + // } + const backupToElse = token.Pos(len("else ")) // The AST doesn't remember the else location. We can make an accurate guess. + switch stmt := n.Else.(type) { + case *ast.IfStmt: + block := &ast.BlockStmt{ + Lbrace: stmt.If - backupToElse, // So the covered part looks like it starts at the "else". + List: []ast.Stmt{stmt}, + Rbrace: stmt.End(), + } + n.Else = block + case *ast.BlockStmt: + stmt.Lbrace -= backupToElse // So the block looks like it starts at the "else". + default: + panic("unexpected node type in if") + } + ast.Walk(f, n.Else) + return nil case *ast.SelectStmt: // Don't annotate an empty select - creates a syntax error. if n.Body == nil || len(n.Body.List) == 0 { @@ -507,7 +539,7 @@ func (f *File) addVariables(w io.Writer) { for i, block := range f.blocks { start := f.fset.Position(block.startByte) end := f.fset.Position(block.endByte) - fmt.Fprintf(w, "\t\t%d, %d, %#x, // %d\n", start.Line, end.Line, (end.Column&0xFFFF)<<16|(start.Column&0xFFFF), i) + fmt.Fprintf(w, "\t\t%d, %d, %#x, // [%d]\n", start.Line, end.Line, (end.Column&0xFFFF)<<16|(start.Column&0xFFFF), i) } // Close the position array. diff --git a/cmd/cover/testdata/main.go b/cmd/cover/testdata/main.go index 4ceff30e..c704b157 100644 --- a/cmd/cover/testdata/main.go +++ b/cmd/cover/testdata/main.go @@ -33,32 +33,61 @@ func check(line, count uint32) { counters[b] = true } +// checkVal is a version of check that returns its extra argument, +// so it can be used in conditionals. +func checkVal(line, count uint32, val int) int { + b := block{ + count, + line, + } + counters[b] = true + return val +} + +var PASS = true + // verify checks the expected counts against the actual. It runs after the test has completed. func verify() { - ok := true for b := range counters { - got := count(b.line) + got, index := count(b.line) if b.count == anything && got != 0 { got = anything } if got != b.count { - fmt.Fprintf(os.Stderr, "test_go:%d expected count %d got %d\n", b.line, b.count, got) - ok = false + fmt.Fprintf(os.Stderr, "test_go:%d expected count %d got %d [counter %d]\n", b.line, b.count, got, index) + PASS = false } } - if !ok { + if !PASS { fmt.Fprintf(os.Stderr, "FAIL\n") os.Exit(2) } } -func count(line uint32) uint32 { - // Linear search is fine. +// count returns the count and index for the counter at the specified line. +func count(line uint32) (uint32, int) { + // Linear search is fine. Choose perfect fit over approximate. + // We can have a closing brace for a range on the same line as a condition for an "else if" + // and we don't want that brace to steal the count for the condition on the "if". + // Therefore we test for a perfect (lo==line && hi==line) match, but if we can't + // find that we take the first imperfect match. + index := -1 + indexLo := uint32(1e9) for i := range coverTest.Count { lo, hi := coverTest.Pos[3*i], coverTest.Pos[3*i+1] - if lo <= line && line <= hi { - return coverTest.Count[i] + if lo == line && line == hi { + return coverTest.Count[i], i + } + // Choose the earliest match (the counters are in unpredictable order). + if lo <= line && line <= hi && indexLo > lo { + index = i + indexLo = lo } } - return 0 + if index == -1 { + fmt.Fprintln(os.Stderr, "cover_test: no counter for line", line) + PASS = false + return 0, 0 + } + return coverTest.Count[index], index } diff --git a/cmd/cover/testdata/test.go b/cmd/cover/testdata/test.go index a15a1288..0c5ea786 100644 --- a/cmd/cover/testdata/test.go +++ b/cmd/cover/testdata/test.go @@ -15,8 +15,10 @@ const anything = 1e9 // Just some unlikely value that means "we got here, don't func testAll() { testSimple() testBlockRun() + testIf() testFor() testSwitch() + testTypeSwitch() testSelect1() testSelect2() } @@ -25,6 +27,48 @@ func testSimple() { check(LINE, 1) } +func testIf() { + if true { + check(LINE, 1) + } else { + check(LINE, 0) + } + if false { + check(LINE, 0) + } else { + check(LINE, 1) + } + for i := 0; i < 3; i++ { + if checkVal(LINE, 3, i) <= 2 { + check(LINE, 3) + } + if checkVal(LINE, 3, i) <= 1 { + check(LINE, 2) + } + if checkVal(LINE, 3, i) <= 0 { + check(LINE, 1) + } + } + for i := 0; i < 3; i++ { + if checkVal(LINE, 3, i) <= 1 { + check(LINE, 2) + } else { + check(LINE, 1) + } + } + for i := 0; i < 3; i++ { + if checkVal(LINE, 3, i) <= 0 { + check(LINE, 1) + } else if checkVal(LINE, 2, i) <= 1 { + check(LINE, 1) + } else if checkVal(LINE, 1, i) <= 2 { + check(LINE, 1) + } else if checkVal(LINE, 0, i) <= 3 { + check(LINE, 0) + } + } +} + func testFor() { for i := 0; i < 10; i++ { check(LINE, 10) @@ -64,6 +108,24 @@ func testSwitch() { } } +func testTypeSwitch() { + var x = []interface{}{1, 2.0, "hi"} + for _, v := range x { + switch v.(type) { + case int: + check(LINE, 1) + case float64: + check(LINE, 1) + case string: + check(LINE, 1) + case complex128: + check(LINE, 0) + default: + check(LINE, 0) + } + } +} + func testSelect1() { c := make(chan int) go func() {