go/packages: deduplicate roots

refine needs to match packages back to roots, and it gets confused if
the same root appears twice. Deduplicate roots the same way we
deduplicate packages.

Introduce a new wrapper around driverResponse to make this easier, and
pass that around instead of the addPackage callback.

Fixes golang/go#29297

Change-Id: I49ea37155c507af136391b9eb55a83b6dedfcc14
Reviewed-on: https://go-review.googlesource.com/c/155020
Run-TryBot: Heschi Kreinick <heschi@google.com>
Reviewed-by: Michael Matloob <matloob@golang.org>
This commit is contained in:
Heschi Kreinick 2018-12-19 16:23:27 -05:00
parent 92cdcd90bf
commit 3571f65a7b
2 changed files with 87 additions and 49 deletions

View File

@ -34,6 +34,42 @@ type goTooOldError struct {
error error
} }
// responseDeduper wraps a driverResponse, deduplicating its contents.
// Packages and roots added through addPackage and addRoot are appended to
// the wrapped response only the first time their ID is seen.
type responseDeduper struct {
// seenRoots holds every root ID recorded so far.
seenRoots map[string]bool
// seenPackages maps a package ID to the package recorded for it.
seenPackages map[string]*Package
// dr is the wrapped response that addPackage and addRoot append to.
dr *driverResponse
}
// init points r at dr and rebuilds both lookup tables from what dr already
// contains, so that later addPackage and addRoot calls skip anything the
// driver has already reported. Any state left from an earlier init is
// discarded. Entries already present in dr are indexed, not removed.
func (r *responseDeduper) init(dr *driverResponse) {
	*r = responseDeduper{
		dr:           dr,
		seenRoots:    make(map[string]bool),
		seenPackages: make(map[string]*Package),
	}
	for _, id := range dr.Roots {
		r.seenRoots[id] = true
	}
	for _, p := range dr.Packages {
		r.seenPackages[p.ID] = p
	}
}
// addPackage appends p to the wrapped response and records it under p.ID.
// If a package is already recorded for that ID, p is dropped and the
// earlier package is kept.
func (r *responseDeduper) addPackage(p *Package) {
	if prev := r.seenPackages[p.ID]; prev == nil {
		r.seenPackages[p.ID] = p
		r.dr.Packages = append(r.dr.Packages, p)
	}
}
// addRoot appends id to the wrapped response's roots the first time it is
// seen. Repeated IDs are ignored, so each root appears at most once among
// the roots added through this method.
func (r *responseDeduper) addRoot(id string) {
	if !r.seenRoots[id] {
		r.seenRoots[id] = true
		r.dr.Roots = append(r.dr.Roots, id)
	}
}
// goListDriver uses the go list command to interpret the patterns and produce // goListDriver uses the go list command to interpret the patterns and produce
// the build system package structure. // the build system package structure.
// See driver for more details. // See driver for more details.
@ -99,17 +135,18 @@ extractQueries:
return response, err return response, err
} }
var response *driverResponse response := &responseDeduper{}
var err error var err error
// see if we have any patterns to pass through to go list. // see if we have any patterns to pass through to go list.
if len(restPatterns) > 0 { if len(restPatterns) > 0 {
response, err = listfunc(cfg, restPatterns...) dr, err := listfunc(cfg, restPatterns...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
response.init(dr)
} else { } else {
response = &driverResponse{} response.init(&driverResponse{})
} }
sizeswg.Wait() sizeswg.Wait()
@ -117,38 +154,23 @@ extractQueries:
return nil, sizeserr return nil, sizeserr
} }
// types.SizesFor always returns nil or a *types.StdSizes // types.SizesFor always returns nil or a *types.StdSizes
response.Sizes, _ = sizes.(*types.StdSizes) response.dr.Sizes, _ = sizes.(*types.StdSizes)
seenPkgs := make(map[string]*Package) // for deduplication. different containing queries could produce same packages
for _, pkg := range response.Packages {
seenPkgs[pkg.ID] = pkg
}
addPkg := func(p *Package) {
if _, ok := seenPkgs[p.ID]; ok {
return
}
seenPkgs[p.ID] = p
response.Packages = append(response.Packages, p)
}
var containsCandidates []string var containsCandidates []string
if len(containFiles) != 0 { if len(containFiles) != 0 {
containsCandidates, err = runContainsQueries(cfg, listfunc, isFallback, addPkg, containFiles) if err := runContainsQueries(cfg, listfunc, isFallback, response, containFiles); err != nil {
if err != nil {
return nil, err return nil, err
} }
} }
if len(packagesNamed) != 0 { if len(packagesNamed) != 0 {
namedResults, err := runNamedQueries(cfg, listfunc, addPkg, packagesNamed) if err := runNamedQueries(cfg, listfunc, response, packagesNamed); err != nil {
if err != nil {
return nil, err return nil, err
} }
response.Roots = append(response.Roots, namedResults...)
} }
modifiedPkgs, needPkgs, err := processGolistOverlay(cfg, response) modifiedPkgs, needPkgs, err := processGolistOverlay(cfg, response.dr)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -158,7 +180,7 @@ extractQueries:
} }
if len(needPkgs) > 0 { if len(needPkgs) > 0 {
addNeededOverlayPackages(cfg, listfunc, addPkg, needPkgs) addNeededOverlayPackages(cfg, listfunc, response, needPkgs)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -166,33 +188,32 @@ extractQueries:
// Check candidate packages for containFiles. // Check candidate packages for containFiles.
if len(containFiles) > 0 { if len(containFiles) > 0 {
for _, id := range containsCandidates { for _, id := range containsCandidates {
pkg := seenPkgs[id] pkg := response.seenPackages[id]
for _, f := range containFiles { for _, f := range containFiles {
for _, g := range pkg.GoFiles { for _, g := range pkg.GoFiles {
if sameFile(f, g) { if sameFile(f, g) {
response.Roots = append(response.Roots, id) response.addRoot(id)
} }
} }
} }
} }
} }
return response, nil return response.dr, nil
} }
// addNeededOverlayPackages runs driver on pkgs and adds every package in its
// response to the accumulated result: through the addPkg callback before this
// change, and through response.addPackage after it. If the driver fails, its
// error is returned unchanged and nothing is added.
func addNeededOverlayPackages(cfg *Config, driver driver, addPkg func(*Package), pkgs []string) error { func addNeededOverlayPackages(cfg *Config, driver driver, response *responseDeduper, pkgs []string) error {
response, err := driver(cfg, pkgs...) dr, err := driver(cfg, pkgs...)
if err != nil { if err != nil {
return err return err
} }
for _, pkg := range response.Packages { for _, pkg := range dr.Packages {
addPkg(pkg) response.addPackage(pkg)
} }
return nil return nil
} }
func runContainsQueries(cfg *Config, driver driver, isFallback bool, addPkg func(*Package), queries []string) ([]string, error) { func runContainsQueries(cfg *Config, driver driver, isFallback bool, response *responseDeduper, queries []string) error {
var results []string
for _, query := range queries { for _, query := range queries {
// TODO(matloob): Do only one query per directory. // TODO(matloob): Do only one query per directory.
fdir := filepath.Dir(query) fdir := filepath.Dir(query)
@ -200,7 +221,7 @@ func runContainsQueries(cfg *Config, driver driver, isFallback bool, addPkg func
// not a package path. // not a package path.
pattern, err := filepath.Abs(fdir) pattern, err := filepath.Abs(fdir)
if err != nil { if err != nil {
return nil, fmt.Errorf("could not determine absolute path of file= query path %q: %v", query, err) return fmt.Errorf("could not determine absolute path of file= query path %q: %v", query, err)
} }
if isFallback { if isFallback {
pattern = "." pattern = "."
@ -209,7 +230,7 @@ func runContainsQueries(cfg *Config, driver driver, isFallback bool, addPkg func
dirResponse, err := driver(cfg, pattern) dirResponse, err := driver(cfg, pattern)
if err != nil { if err != nil {
return nil, err return err
} }
isRoot := make(map[string]bool, len(dirResponse.Roots)) isRoot := make(map[string]bool, len(dirResponse.Roots))
for _, root := range dirResponse.Roots { for _, root := range dirResponse.Roots {
@ -220,34 +241,34 @@ func runContainsQueries(cfg *Config, driver driver, isFallback bool, addPkg func
// We don't bother to filter packages that will be dropped by the changes of roots, // We don't bother to filter packages that will be dropped by the changes of roots,
// that will happen anyway during graph construction outside this function. // that will happen anyway during graph construction outside this function.
// Over-reporting packages is not a problem. // Over-reporting packages is not a problem.
addPkg(pkg) response.addPackage(pkg)
// if the package was not a root one, it cannot have the file // if the package was not a root one, it cannot have the file
if !isRoot[pkg.ID] { if !isRoot[pkg.ID] {
continue continue
} }
for _, pkgFile := range pkg.GoFiles { for _, pkgFile := range pkg.GoFiles {
if filepath.Base(query) == filepath.Base(pkgFile) { if filepath.Base(query) == filepath.Base(pkgFile) {
results = append(results, pkg.ID) response.addRoot(pkg.ID)
break break
} }
} }
} }
} }
return results, nil return nil
} }
// modCacheRegexp splits a path in a module cache into module, module version, and package. // modCacheRegexp splits a path in a module cache into module, module version, and package.
var modCacheRegexp = regexp.MustCompile(`(.*)@([^/\\]*)(.*)`) var modCacheRegexp = regexp.MustCompile(`(.*)@([^/\\]*)(.*)`)
func runNamedQueries(cfg *Config, driver driver, addPkg func(*Package), queries []string) ([]string, error) { func runNamedQueries(cfg *Config, driver driver, response *responseDeduper, queries []string) error {
// calling `go env` isn't free; bail out if there's nothing to do. // calling `go env` isn't free; bail out if there's nothing to do.
if len(queries) == 0 { if len(queries) == 0 {
return nil, nil return nil
} }
// Determine which directories are relevant to scan. // Determine which directories are relevant to scan.
roots, modRoot, err := roots(cfg) roots, modRoot, err := roots(cfg)
if err != nil { if err != nil {
return nil, err return err
} }
// Scan the selected directories. Simple matches, from GOPATH/GOROOT // Scan the selected directories. Simple matches, from GOPATH/GOROOT
@ -306,13 +327,12 @@ func runNamedQueries(cfg *Config, driver driver, addPkg func(*Package), queries
} }
} }
var results []string
addResponse := func(r *driverResponse) { addResponse := func(r *driverResponse) {
for _, pkg := range r.Packages { for _, pkg := range r.Packages {
addPkg(pkg) response.addPackage(pkg)
for _, name := range queries { for _, name := range queries {
if pkg.Name == name { if pkg.Name == name {
results = append(results, pkg.ID) response.addRoot(pkg.ID)
break break
} }
} }
@ -322,7 +342,7 @@ func runNamedQueries(cfg *Config, driver driver, addPkg func(*Package), queries
if len(simpleMatches) != 0 { if len(simpleMatches) != 0 {
resp, err := driver(cfg, simpleMatches...) resp, err := driver(cfg, simpleMatches...)
if err != nil { if err != nil {
return nil, err return err
} }
addResponse(resp) addResponse(resp)
} }
@ -371,23 +391,23 @@ func runNamedQueries(cfg *Config, driver driver, addPkg func(*Package), queries
var err error var err error
tmpCfg.Dir, err = ioutil.TempDir("", "gopackages-modquery") tmpCfg.Dir, err = ioutil.TempDir("", "gopackages-modquery")
if err != nil { if err != nil {
return nil, err return err
} }
defer os.RemoveAll(tmpCfg.Dir) defer os.RemoveAll(tmpCfg.Dir)
if err := ioutil.WriteFile(filepath.Join(tmpCfg.Dir, "go.mod"), gomod.Bytes(), 0777); err != nil { if err := ioutil.WriteFile(filepath.Join(tmpCfg.Dir, "go.mod"), gomod.Bytes(), 0777); err != nil {
return nil, fmt.Errorf("writing go.mod for module cache query: %v", err) return fmt.Errorf("writing go.mod for module cache query: %v", err)
} }
// Run the query, using the import paths calculated from the matches above. // Run the query, using the import paths calculated from the matches above.
resp, err := driver(&tmpCfg, imports...) resp, err := driver(&tmpCfg, imports...)
if err != nil { if err != nil {
return nil, fmt.Errorf("querying module cache matches: %v", err) return fmt.Errorf("querying module cache matches: %v", err)
} }
addResponse(resp) addResponse(resp)
} }
return results, nil return nil
} }
func getSizes(cfg *Config) (types.Sizes, error) { func getSizes(cfg *Config) (types.Sizes, error) {

View File

@ -1116,7 +1116,6 @@ func testSizes(t *testing.T, exporter packagestest.Exporter) {
t.Errorf("for GOARCH=%s, got word size %d, want %d", arch, gotWordSize, wantWordSize) t.Errorf("for GOARCH=%s, got word size %d, want %d", arch, gotWordSize, wantWordSize)
} }
} }
} }
// TestContains_FallbackSticks ensures that when there are both contains and non-contains queries // TestContains_FallbackSticks ensures that when there are both contains and non-contains queries
@ -1259,6 +1258,25 @@ func TestName_ModulesDedup(t *testing.T) {
t.Errorf("didn't find v2.0.2 of pkg in Load results: %v", initial) t.Errorf("didn't find v2.0.2 of pkg in Load results: %v", initial)
} }
// TestRedundantQueries checks that Load doesn't get confused when two
// different patterns match the same package. It runs testRedundantQueries
// through packagestest.TestAll. See golang/go#29297.
func TestRedundantQueries(t *testing.T) { packagestest.TestAll(t, testRedundantQueries) }
// testRedundantQueries loads the standard "errors" package through two
// overlapping patterns, a plain import path and a name= query, and requires
// that Load report it exactly once.
func testRedundantQueries(t *testing.T, exporter packagestest.Exporter) {
	modules := []packagestest.Module{{
		Name: "golang.org/fake",
		Files: map[string]interface{}{
			"a/a.go": `package a;`,
		},
	}}
	exported := packagestest.Export(t, exporter, modules)
	defer exported.Cleanup()

	// Both patterns resolve to the same package.
	loaded, err := packages.Load(exported.Config, "errors", "name=errors")
	if err != nil {
		t.Fatal(err)
	}
	if len(loaded) == 1 && loaded[0].Name == "errors" {
		return
	}
	t.Fatalf(`Load("errors", "name=errors") = %v, wanted just the errors package`, loaded)
}
func TestJSON(t *testing.T) { packagestest.TestAll(t, testJSON) } func TestJSON(t *testing.T) { packagestest.TestAll(t, testJSON) }
func testJSON(t *testing.T, exporter packagestest.Exporter) { func testJSON(t *testing.T, exporter packagestest.Exporter) {
//TODO: add in some errors //TODO: add in some errors