diff --git a/cmd/callgraph/main.go b/cmd/callgraph/main.go index eed92aaf..3a5a2867 100644 --- a/cmd/callgraph/main.go +++ b/cmd/callgraph/main.go @@ -221,12 +221,12 @@ func doCallgraph(ctxt *build.Context, algo, format string, tests bool, args []st } } - main, err := mainPackage(prog, tests) + mains, err := mainPackages(prog, tests) if err != nil { return err } config := &pointer.Config{ - Mains: []*ssa.Package{main}, + Mains: mains, BuildCallGraph: true, Log: ptalog, } @@ -237,13 +237,13 @@ func doCallgraph(ctxt *build.Context, algo, format string, tests bool, args []st cg = ptares.CallGraph case "rta": - main, err := mainPackage(prog, tests) + mains, err := mainPackages(prog, tests) if err != nil { return err } - roots := []*ssa.Function{ - main.Func("init"), - main.Func("main"), + var roots []*ssa.Function + for _, main := range mains { + roots = append(roots, main.Func("init"), main.Func("main")) } rtares := rta.Analyze(roots, true) cg = rtares.CallGraph @@ -303,35 +303,31 @@ func doCallgraph(ctxt *build.Context, algo, format string, tests bool, args []st return nil } -// mainPackage returns the main package to analyze. -// The resulting package has a main() function. -func mainPackage(prog *ssa.Program, tests bool) (*ssa.Package, error) { - pkgs := prog.AllPackages() - - // TODO(adonovan): allow independent control over tests, mains and libraries. - // TODO(adonovan): put this logic in a library; we keep reinventing it. +// mainPackages returns the main packages to analyze. +// Each resulting package is named "main" and has a main function. +func mainPackages(prog *ssa.Program, tests bool) ([]*ssa.Package, error) { + pkgs := prog.AllPackages() // TODO(adonovan): use only initial packages + // If tests, create a "testmain" package for each test. + var mains []*ssa.Package if tests { - // If -test, use all packages' tests. - if len(pkgs) > 0 { - if main := prog.CreateTestMainPackage(pkgs...); main != nil { - return main, nil + for _, pkg := range pkgs { + if main := prog.CreateTestMainPackage(pkg); main != nil { + mains = append(mains, main) } } - return nil, fmt.Errorf("no tests") - } - - // Otherwise, use the first package named main. - for _, pkg := range pkgs { - if pkg.Pkg.Name() == "main" { - if pkg.Func("main") == nil { - return nil, fmt.Errorf("no func main() in main package") - } - return pkg, nil + if mains == nil { + return nil, fmt.Errorf("no tests") } + return mains, nil } - return nil, fmt.Errorf("no main package") + // Otherwise, use the main packages. + mains = append(mains, ssautil.MainPackages(pkgs)...) + if len(mains) == 0 { + return nil, fmt.Errorf("no main packages") + } + return mains, nil } type Edge struct { diff --git a/cmd/callgraph/main_test.go b/cmd/callgraph/main_test.go index f1f7166d..c42f56da 100644 --- a/cmd/callgraph/main_test.go +++ b/cmd/callgraph/main_test.go @@ -47,14 +47,14 @@ func TestCallgraph(t *testing.T) { }}, // tests: main is not called. {"rta", format, true, []string{ + `pkg$testmain.init --> pkg.init`, `pkg.Example --> (pkg.C).f`, - `test$main.init --> pkg.init`, }}, {"pta", format, true, []string{ + ` --> pkg$testmain.init`, ` --> pkg.Example`, - ` --> test$main.init`, + `pkg$testmain.init --> pkg.init`, `pkg.Example --> (pkg.C).f`, - `test$main.init --> pkg.init`, }}, } { stdout = new(bytes.Buffer) diff --git a/cmd/guru/guru.go b/cmd/guru/guru.go index 9e631130..25457fe0 100644 --- a/cmd/guru/guru.go +++ b/cmd/guru/guru.go @@ -129,26 +129,18 @@ func setPTAScope(lconf *loader.Config, scope []string) error { // Create a pointer.Config whose scope is the initial packages of lprog // and their dependencies. func setupPTA(prog *ssa.Program, lprog *loader.Program, ptaLog io.Writer, reflection bool) (*pointer.Config, error) { - // TODO(adonovan): the body of this function is essentially - // duplicated in all go/pointer clients. Refactor. - // For each initial package (specified on the command line), // if it has a main function, analyze that, // otherwise analyze its tests, if any. - var testPkgs, mains []*ssa.Package + var mains []*ssa.Package for _, info := range lprog.InitialPackages() { - initialPkg := prog.Package(info.Pkg) + p := prog.Package(info.Pkg) // Add package to the pointer analysis scope. - if initialPkg.Func("main") != nil { - mains = append(mains, initialPkg) - } else { - testPkgs = append(testPkgs, initialPkg) - } - } - if testPkgs != nil { - if p := prog.CreateTestMainPackage(testPkgs...); p != nil { + if p.Pkg.Name() == "main" && p.Func("main") != nil { mains = append(mains, p) + } else if main := prog.CreateTestMainPackage(p); main != nil { + mains = append(mains, main) } } if mains == nil { diff --git a/cmd/ssadump/main.go b/cmd/ssadump/main.go index bf2bbf24..9a0ba7d5 100644 --- a/cmd/ssadump/main.go +++ b/cmd/ssadump/main.go @@ -125,46 +125,44 @@ func doMain() error { } // Load, parse and type-check the whole program. - iprog, err := conf.Load() + lprog, err := conf.Load() if err != nil { return err } // Create and build SSA-form program representation. - prog := ssautil.CreateProgram(iprog, mode) + prog := ssautil.CreateProgram(lprog, mode) // Build and display only the initial packages // (and synthetic wrappers), unless -run is specified. - for _, info := range iprog.InitialPackages() { - prog.Package(info.Pkg).Build() + var initpkgs []*ssa.Package + for _, info := range lprog.InitialPackages() { + ssapkg := prog.Package(info.Pkg) + ssapkg.Build() + if info.Pkg.Path() != "runtime" { + initpkgs = append(initpkgs, ssapkg) + } } // Run the interpreter. if *runFlag { prog.Build() - var main *ssa.Package - pkgs := prog.AllPackages() + var mains []*ssa.Package if *testFlag { - // If -test, run all packages' tests. - if len(pkgs) > 0 { - main = prog.CreateTestMainPackage(pkgs...) + // If -test, run the tests. + for _, pkg := range initpkgs { + if main := prog.CreateTestMainPackage(pkg); main != nil { + mains = append(mains, main) + } } - if main == nil { + if mains == nil { return fmt.Errorf("no tests") } } else { - // Otherwise, run main.main. - for _, pkg := range pkgs { - if pkg.Pkg.Name() == "main" { - main = pkg - if main.Func("main") == nil { - return fmt.Errorf("no func main() in main package") - } - break - } - } - if main == nil { + // Otherwise, run the main packages. + mains := ssautil.MainPackages(initpkgs) + if len(mains) == 0 { return fmt.Errorf("no main package") } } @@ -174,7 +172,12 @@ func doMain() error { build.Default.GOARCH, runtime.GOARCH) } - interp.Interpret(main, interpMode, conf.TypeChecker.Sizes, main.Pkg.Path(), args) + for _, main := range mains { + if len(mains) > 1 { + fmt.Fprintf(os.Stderr, "Running: %s\n", main.Pkg.Path()) + } + interp.Interpret(main, interpMode, conf.TypeChecker.Sizes, main.Pkg.Path(), args) + } } return nil } diff --git a/go/pointer/stdlib_test.go b/go/pointer/stdlib_test.go index d3d14ea3..d3ba7216 100644 --- a/go/pointer/stdlib_test.go +++ b/go/pointer/stdlib_test.go @@ -60,20 +60,22 @@ func TestStdlib(t *testing.T) { } // Determine the set of packages/tests to analyze. - var testPkgs []*ssa.Package + var mains []*ssa.Package for _, info := range iprog.InitialPackages() { - testPkgs = append(testPkgs, prog.Package(info.Pkg)) + ssapkg := prog.Package(info.Pkg) + if main := prog.CreateTestMainPackage(ssapkg); main != nil { + mains = append(mains, main) + } } - testmain := prog.CreateTestMainPackage(testPkgs...) - if testmain == nil { - t.Fatal("analysis scope has tests") + if mains == nil { + t.Fatal("no tests found in analysis scope") } // Run the analysis. config := &Config{ Reflection: false, // TODO(adonovan): fix remaining bug in rVCallConstraint, then enable. BuildCallGraph: true, - Mains: []*ssa.Package{testmain}, + Mains: mains, } // TODO(adonovan): add some query values (affects track bits). diff --git a/go/ssa/interp/interp_test.go b/go/ssa/interp/interp_test.go index 70c2f400..d51c2fd9 100644 --- a/go/ssa/interp/interp_test.go +++ b/go/ssa/interp/interp_test.go @@ -222,29 +222,25 @@ func run(t *testing.T, dir, input string, success successPredicate) bool { prog := ssautil.CreateProgram(iprog, ssa.SanityCheckFunctions) prog.Build() + // Find first main or test package among the initial packages. var mainPkg *ssa.Package - var initialPkgs []*ssa.Package for _, info := range iprog.InitialPackages() { if info.Pkg.Path() == "runtime" { continue // not an initial package } p := prog.Package(info.Pkg) - initialPkgs = append(initialPkgs, p) - if mainPkg == nil && p.Func("main") != nil { + if p.Pkg.Name() == "main" && p.Func("main") != nil { mainPkg = p + break + } + + mainPkg = prog.CreateTestMainPackage(p) + if mainPkg != nil { + break } } if mainPkg == nil { - testmainPkg := prog.CreateTestMainPackage(initialPkgs...) - if testmainPkg == nil { - t.Errorf("CreateTestMainPackage(%s) returned nil", mainPkg) - return false - } - if testmainPkg.Func("main") == nil { - t.Errorf("synthetic testmain package has no main") - return false - } - mainPkg = testmainPkg + t.Fatalf("no main or test packages among initial packages: %s", inputs) } var out bytes.Buffer @@ -346,6 +342,23 @@ func TestTestmainPackage(t *testing.T) { return nil } run(t, "testdata"+slash, "a_test.go", success) + + // Run a test with a custom TestMain function and ensure that it + // is executed, and that m.Run runs the tests. + success = func(exitcode int, output string) error { + if exitcode != 0 { + return fmt.Errorf("unexpected failure; output=%s", output) + } + if want := `TestMain start +TestC +PASS +TestMain end +`; output != want { + return fmt.Errorf("output was %q, want %q", output, want) + } + return nil + } + run(t, "testdata"+slash, "c_test.go", success) } // CreateTestMainPackage should return nil if there were no tests. diff --git a/go/ssa/interp/testdata/c_test.go b/go/ssa/interp/testdata/c_test.go new file mode 100644 index 00000000..ad80b910 --- /dev/null +++ b/go/ssa/interp/testdata/c_test.go @@ -0,0 +1,17 @@ +package c_test + +import ( + "os" + "testing" +) + +func TestC(t *testing.T) { + println("TestC") +} + +func TestMain(m *testing.M) { + println("TestMain start") + code := m.Run() + println("TestMain end") + os.Exit(code) +} diff --git a/go/ssa/ssautil/visit.go b/go/ssa/ssautil/visit.go index 6c51f930..3424e8a3 100644 --- a/go/ssa/ssautil/visit.go +++ b/go/ssa/ssautil/visit.go @@ -64,3 +64,16 @@ func (visit *visitor) function(fn *ssa.Function) { } } } + +// MainPackages returns the subset of the specified packages +// named "main" that define a main function. +// The result may include synthetic "testmain" packages. +func MainPackages(pkgs []*ssa.Package) []*ssa.Package { + var mains []*ssa.Package + for _, pkg := range pkgs { + if pkg.Pkg.Name() == "main" && pkg.Func("main") != nil { + mains = append(mains, pkg) + } + } + return mains +} diff --git a/go/ssa/testmain.go b/go/ssa/testmain.go index 48b184a3..a791e4b9 100644 --- a/go/ssa/testmain.go +++ b/go/ssa/testmain.go @@ -12,23 +12,18 @@ package ssa import ( "go/ast" - exact "go/constant" "go/token" "go/types" + "log" "os" - "sort" "strings" ) -// FindTests returns the list of packages that define at least one Test, -// Example or Benchmark function (as defined by "go test"), and the -// lists of all such functions. -// -func FindTests(pkgs []*Package) (testpkgs []*Package, tests, benchmarks, examples []*Function) { - if len(pkgs) == 0 { - return - } - prog := pkgs[0].Prog +// FindTests returns the Test, Benchmark, and Example functions +// (as defined by "go test") defined in the specified package, +// and its TestMain function, if any. +func FindTests(pkg *Package) (tests, benchmarks, examples []*Function, main *Function) { + prog := pkg.Prog // The first two of these may be nil: if the program doesn't import "testing", // it can't contain any tests, but it may yet contain Examples. @@ -36,40 +31,41 @@ func FindTests(pkgs []*Package) (testpkgs []*Package, tests, benchmarks, example var benchmarkSig *types.Signature // func(*testing.B) var exampleSig = types.NewSignature(nil, nil, nil, false) // func() - // Obtain the types from the parameters of testing.Main(). + // Obtain the types from the parameters of testing.MainStart. if testingPkg := prog.ImportedPackage("testing"); testingPkg != nil { - params := testingPkg.Func("Main").Signature.Params() + mainStart := testingPkg.Func("MainStart") + params := mainStart.Signature.Params() testSig = funcField(params.At(1).Type()) benchmarkSig = funcField(params.At(2).Type()) + + // Does the package define this function? + // func TestMain(*testing.M) + if f := pkg.Func("TestMain"); f != nil { + sig := f.Type().(*types.Signature) + starM := mainStart.Signature.Results().At(0).Type() // *testing.M + if sig.Results().Len() == 0 && + sig.Params().Len() == 1 && + types.Identical(sig.Params().At(0).Type(), starM) { + main = f + } + } } - seen := make(map[*Package]bool) - for _, pkg := range pkgs { - if pkg.Prog != prog { - panic("wrong Program") - } + // TODO(adonovan): use a stable order, e.g. lexical. + for _, mem := range pkg.Members { + if f, ok := mem.(*Function); ok && + ast.IsExported(f.Name()) && + strings.HasSuffix(prog.Fset.Position(f.Pos()).Filename, "_test.go") { - // TODO(adonovan): use a stable order, e.g. lexical. - for _, mem := range pkg.Members { - if f, ok := mem.(*Function); ok && - ast.IsExported(f.Name()) && - strings.HasSuffix(prog.Fset.Position(f.Pos()).Filename, "_test.go") { - - switch { - case testSig != nil && isTestSig(f, "Test", testSig): - tests = append(tests, f) - case benchmarkSig != nil && isTestSig(f, "Benchmark", benchmarkSig): - benchmarks = append(benchmarks, f) - case isTestSig(f, "Example", exampleSig): - examples = append(examples, f) - default: - continue - } - - if !seen[pkg] { - seen[pkg] = true - testpkgs = append(testpkgs, pkg) - } + switch { + case testSig != nil && isTestSig(f, "Test", testSig): + tests = append(tests, f) + case benchmarkSig != nil && isTestSig(f, "Benchmark", benchmarkSig): + benchmarks = append(benchmarks, f) + case isTestSig(f, "Example", exampleSig): + examples = append(examples, f) + default: + continue } } } @@ -87,15 +83,22 @@ func isTestSig(f *Function, prefix string, sig *types.Signature) bool { // systems that don't exactly follow 'go test' conventions. var testMainStartBodyHook func(*Function) -// CreateTestMainPackage creates and returns a synthetic "main" -// package that runs all the tests of the supplied packages, similar -// to the one that would be created by the 'go test' tool. +// CreateTestMainPackage creates and returns a synthetic "testmain" +// package for the specified package if it defines tests, benchmarks or +// executable examples, or nil otherwise. The new package is named +// "main" and provides a function named "main" that runs the tests, +// similar to the one that would be created by the 'go test' tool. // -// It returns nil if the program contains no tests. -// -func (prog *Program) CreateTestMainPackage(pkgs ...*Package) *Package { - pkgs, tests, benchmarks, examples := FindTests(pkgs) - if len(pkgs) == 0 { +// Subsequent calls to prog.AllPackages include the new package. +// The package pkg must belong to the program prog. +func (prog *Program) CreateTestMainPackage(pkg *Package) *Package { + if pkg.Prog != prog { + log.Fatal("Package does not belong to Program") + } + + tests, benchmarks, examples, testMainFunc := FindTests(pkg) + + if testMainFunc == nil && tests == nil && benchmarks == nil && examples == nil { return nil } @@ -103,7 +106,7 @@ func (prog *Program) CreateTestMainPackage(pkgs ...*Package) *Package { Prog: prog, Members: make(map[string]Member), values: make(map[types.Object]Value), - Pkg: types.NewPackage("test$main", "main"), + Pkg: types.NewPackage(pkg.Pkg.Path()+"$testmain", "main"), } // Build package's init function. @@ -120,30 +123,18 @@ func (prog *Program) CreateTestMainPackage(pkgs ...*Package) *Package { testMainStartBodyHook(init) } - // Initialize packages to test. - var pkgpaths []string - for _, pkg := range pkgs { - var v Call - v.Call.Value = pkg.init - v.setType(types.NewTuple()) - init.emit(&v) - - pkgpaths = append(pkgpaths, pkg.Pkg.Path()) - } - sort.Strings(pkgpaths) + // Initialize package under test. + var v Call + v.Call.Value = pkg.init + v.setType(types.NewTuple()) + init.emit(&v) init.emit(new(Return)) init.finishBody() testmain.init = init testmain.Pkg.MarkComplete() testmain.Members[init.name] = init - // For debugging convenience, define an unexported const - // that enumerates the packages. - packagesConst := types.NewConst(token.NoPos, testmain.Pkg, "packages", tString, - exact.MakeString(strings.Join(pkgpaths, " "))) - memberFromObject(testmain, packagesConst, nil) - - // Create main *types.Func and *ssa.Function + // Create main *types.Func and *Function mainFunc := types.NewFunc(token.NoPos, testmain.Pkg, "main", new(types.Signature)) memberFromObject(testmain, mainFunc, nil) main := testmain.Func("main") @@ -166,7 +157,12 @@ func (prog *Program) CreateTestMainPackage(pkgs ...*Package) *Package { // tests := []testing.InternalTest{{"TestFoo", TestFoo}, ...} // benchmarks := []testing.InternalBenchmark{...} // examples := []testing.InternalExample{...} - // testing.Main(match, tests, benchmarks, examples) + // if TestMain is defined { + // m := testing.MainStart(match, tests, benchmarks, examples) + // return TestMain(m) + // } else { + // return testing.Main(match, tests, benchmarks, examples) + // } // } matcher := &Function{ @@ -182,23 +178,41 @@ func (prog *Program) CreateTestMainPackage(pkgs ...*Package) *Package { matcher.emit(&Return{Results: []Value{vTrue, nilConst(types.Universe.Lookup("error").Type())}}) matcher.finishBody() - // Emit call: testing.Main(matcher, tests, benchmarks, examples). var c Call - c.Call.Value = testingMain c.Call.Args = []Value{ matcher, testMainSlice(main, tests, testingMainParams.At(1).Type()), testMainSlice(main, benchmarks, testingMainParams.At(2).Type()), testMainSlice(main, examples, testingMainParams.At(3).Type()), } - emitTailCall(main, &c) + if testMainFunc != nil { + // Emit: m := testing.MainStart(matcher, tests, benchmarks, examples). + // (Main and MainStart have the same parameters.) + mainStart := testingPkg.Func("MainStart") + c.Call.Value = mainStart + c.setType(mainStart.Signature.Results().At(0).Type()) // *testing.M + m := main.emit(&c) + + // Emit: return TestMain(m) + var c2 Call + c2.Call.Value = testMainFunc + c2.Call.Args = []Value{m} + emitTailCall(main, &c2) + } else { + // Emit: return testing.Main(matcher, tests, benchmarks, examples) + c.Call.Value = testingMain + emitTailCall(main, &c) + } } else { // The program does not import "testing", but FindTests // returned non-nil, which must mean there were Examples - // but no Tests or Benchmarks. + // but no Test, Benchmark, or TestMain functions. + // We'll simply call them from testmain.main; this will // ensure they don't panic, but will not check any // "Output:" comments. + // (We should not execute an Example that has no + // "Output:" comment, but it's impossible to tell here.) for _, eg := range examples { var c Call c.Call.Value = eg diff --git a/go/ssa/testmain_test.go b/go/ssa/testmain_test.go index 56cb6040..e24b23b9 100644 --- a/go/ssa/testmain_test.go +++ b/go/ssa/testmain_test.go @@ -17,7 +17,7 @@ import ( "golang.org/x/tools/go/ssa/ssautil" ) -func create(t *testing.T, content string) []*ssa.Package { +func create(t *testing.T, content string) *ssa.Package { var conf loader.Config f, err := conf.ParseFile("foo_test.go", content) if err != nil { @@ -25,13 +25,14 @@ func create(t *testing.T, content string) []*ssa.Package { } conf.CreateFromFiles("foo", f) - iprog, err := conf.Load() + lprog, err := conf.Load() if err != nil { t.Fatal(err) } // We needn't call Build. - return ssautil.CreateProgram(iprog, ssa.SanityCheckFunctions).AllPackages() + foo := lprog.Package("foo").Pkg + return ssautil.CreateProgram(lprog, ssa.SanityCheckFunctions).Package(foo) } func TestFindTests(t *testing.T) { @@ -74,8 +75,8 @@ func ExampleD(t *testing.T) {} func exampleE() int { return 0 } func (T) Example() {} ` - pkgs := create(t, test) - _, tests, benchmarks, examples := ssa.FindTests(pkgs) + pkg := create(t, test) + tests, benchmarks, examples, _ := ssa.FindTests(pkg) sort.Sort(funcsByPos(tests)) if got, want := fmt.Sprint(tests), "[foo.Test foo.TestA foo.TestB]"; got != want { @@ -102,8 +103,8 @@ package foo func Example() {} func ExampleA() {} ` - pkgs := create(t, test) - _, tests, benchmarks, examples := ssa.FindTests(pkgs) + pkg := create(t, test) + tests, benchmarks, examples, _ := ssa.FindTests(pkg) if len(tests) > 0 { t.Errorf("FindTests.tests = %s, want none", tests) } diff --git a/godoc/analysis/analysis.go b/godoc/analysis/analysis.go index 633428ca..5ee03425 100644 --- a/godoc/analysis/analysis.go +++ b/godoc/analysis/analysis.go @@ -47,7 +47,6 @@ package analysis // import "golang.org/x/tools/godoc/analysis" import ( "fmt" "go/build" - exact "go/constant" "go/scanner" "go/token" "go/types" @@ -396,21 +395,12 @@ func Run(pta bool, result *Result) { // Only the transitively error-free packages are used. prog := ssautil.CreateProgram(iprog, ssa.GlobalDebug) - // Compute the set of main packages, including testmain. - allPackages := prog.AllPackages() - var mainPkgs []*ssa.Package - if testmain := prog.CreateTestMainPackage(allPackages...); testmain != nil { - mainPkgs = append(mainPkgs, testmain) - if p := testmain.Const("packages"); p != nil { - log.Printf("Tested packages: %v", exact.StringVal(p.Value.Value)) + // Create a "testmain" package for each package with tests. + for _, pkg := range prog.AllPackages() { + if testmain := prog.CreateTestMainPackage(pkg); testmain != nil { + log.Printf("Adding tests for %s", pkg.Pkg.Path()) } } - for _, pkg := range allPackages { - if pkg.Pkg.Name() == "main" && pkg.Func("main") != nil { - mainPkgs = append(mainPkgs, pkg) - } - } - log.Print("Transitively error-free main packages: ", mainPkgs) // Build SSA code for bodies of all functions in the whole program. result.setStatusf("Constructing SSA form...") @@ -505,6 +495,8 @@ func Run(pta bool, result *Result) { result.setStatusf("Type analysis complete.") if pta { + mainPkgs := ssautil.MainPackages(prog.AllPackages()) + log.Print("Transitively error-free main packages: ", mainPkgs) a.pointer(mainPkgs) } }