diff --git a/cmd/cover/cover.go b/cmd/cover/cover.go index e4f7a421..d6ce70a8 100644 --- a/cmd/cover/cover.go +++ b/cmd/cover/cover.go @@ -25,10 +25,11 @@ import ( "log" "os" "sort" + "strconv" ) var ( - mode = flag.String("mode", "set", "coverage mode: set, sum, atomic") + mode = flag.String("mode", "set", "coverage mode: set, count, atomic") countVar = flag.String("count", "__count", "name of coverage count array variable") posVar = flag.String("pos", "__pos", "name of coverage count position variable") ) @@ -36,8 +37,8 @@ var ( var counterStmt func(*File, ast.Expr) ast.Stmt const ( - coveragePackagePath = "code.google.com/p/go.tools/coverage" - atomicPackagePath = "sync/atomic" + atomicPackagePath = "sync/atomic" + atomicPackageName = "_cover_atomic_" ) func usage() { @@ -52,7 +53,7 @@ func main() { switch *mode { case "set": counterStmt = setCounterStmt - case "sum": + case "count": counterStmt = incCounterStmt case "atomic": counterStmt = atomicCounterStmt @@ -77,12 +78,11 @@ type Block struct { // File is a wrapper for the state of a file used in the parser. // The basic parse tree walker is a method of this type. type File struct { - fset *token.FileSet - name string // Name of file. - astFile *ast.File - blocks []Block - coveragePkg string // Package name for ".../coverage" in this file. - atomicPkg string // Package name for "sync/atomic" in this file. + fset *token.FileSet + name string // Name of file. + astFile *ast.File + blocks []Block + atomicPkg string // Package name for "sync/atomic" in this file. } // Visit implements the ast.Visitor interface. @@ -111,6 +111,67 @@ func (f *File) Visit(node ast.Node) ast.Visitor { return f } +// unquote returns the unquoted string. +func unquote(s string) string { + t, err := strconv.Unquote(s) + if err != nil { + log.Fatal("cover: improperly quoted string %q\n", s) + } + return t +} + +// addImport adds an import for the specified path, if one does not already exist, and returns +// the local package name. +func (f *File) addImport(path string) string { + // Does the package already import it? + for _, s := range f.astFile.Imports { + if unquote(s.Path.Value) == path { + return s.Name.Name + } + } + newImport := &ast.ImportSpec{ + Name: ast.NewIdent(atomicPackageName), + Path: &ast.BasicLit{ + Kind: token.STRING, + Value: fmt.Sprintf("%q", path), + }, + } + impDecl := &ast.GenDecl{ + Tok: token.IMPORT, + Specs: []ast.Spec{ + newImport, + }, + } + // Make the new import the first Decl in the file. + astFile := f.astFile + astFile.Decls = append(astFile.Decls, nil) + copy(astFile.Decls[1:], astFile.Decls[0:]) + astFile.Decls[0] = impDecl + astFile.Imports = append(astFile.Imports, newImport) + + // Now refer to the package, just in case it ends up unused. + // That is, append to the end of the file the declaration + // var _ = _cover_atomic_.AddUint32 + reference := &ast.GenDecl{ + Tok: token.VAR, + Specs: []ast.Spec{ + &ast.ValueSpec{ + Names: []*ast.Ident{ + ast.NewIdent("_"), + }, + Values: []ast.Expr{ + &ast.SelectorExpr{ + X: ast.NewIdent(atomicPackageName), + Sel: ast.NewIdent("AddUint32"), + }, + }, + }, + }, + } + astFile.Decls = append(astFile.Decls, reference) + return atomicPackageName +} + func cover(name string) { var files []*File var astFiles []*ast.File @@ -136,6 +197,9 @@ func cover(name string) { files = append(files, thisFile) astFiles = append(astFiles, parsedFile) for _, file := range files { + if *mode == "atomic" { + file.atomicPkg = file.addImport(atomicPackagePath) + } ast.Walk(file, file.astFile) file.print(os.Stdout) // After printing the source tree, add some declarations for the counters etc. @@ -379,11 +443,10 @@ func (f *File) addVariables(w io.Writer) { // - 32-bit starting line number // - 32-bit ending line number // - (16 bit ending column number << 16) | (16-bit starting column number). - for i, block := range f.blocks { + for _, block := range f.blocks { start := f.fset.Position(block.startByte) end := f.fset.Position(block.endByte) fmt.Fprintf(w, "\t%d, %d, %#x,\n", start.Line, end.Line, (end.Column&0xFFFF)<<16|(start.Column&0xFFFF)) - fmt.Fprintf(w, "//FOR DEBUGGING: \t%d: %s:#%d,#%d\n", i, f.name, block.startByte, block.endByte) } // Close the declaration.