diff --git a/go/packages/packages.go b/go/packages/packages.go index d35d0104..1a4c1da5 100644 --- a/go/packages/packages.go +++ b/go/packages/packages.go @@ -365,6 +365,9 @@ func newLoader(cfg *Config) *loader { if cfg != nil { ld.Config = *cfg } + if ld.Config.Env == nil { + ld.Config.Env = os.Environ() + } if ld.Context == nil { ld.Context = context.Background() } diff --git a/go/packages/packages_test.go b/go/packages/packages_test.go index b2f6a076..bdf918a1 100644 --- a/go/packages/packages_test.go +++ b/go/packages/packages_test.go @@ -1319,6 +1319,88 @@ func TestJSON(t *testing.T) { } } +func TestConfigDefaultEnv(t *testing.T) { + if runtime.GOOS == "windows" { + // TODO(jayconrod): write an equivalent batch script for windows. + // Hint: "type" can be used to read a file to stdout. + t.Skip("test requires sh") + } + tmp, cleanup := makeTree(t, map[string]string{ + "bin/gopackagesdriver": `#!/bin/sh + +cat - <<'EOF' +{ + "Roots": ["gopackagesdriver"], + "Packages": [{"ID": "gopackagesdriver", "Name": "gopackagesdriver"}] +} +EOF +`, + "src/golist/golist.go": "package golist", + }) + defer cleanup() + if err := os.Chmod(filepath.Join(tmp, "bin", "gopackagesdriver"), 0755); err != nil { + t.Fatal(err) + } + + path, ok := os.LookupEnv("PATH") + var pathWithDriver string + if ok { + pathWithDriver = filepath.Join(tmp, "bin") + string(os.PathListSeparator) + path + } else { + pathWithDriver = filepath.Join(tmp, "bin") + } + + for _, test := range []struct { + desc string + env []string + wantIDs string + }{ + { + desc: "driver_off", + env: []string{"PATH", pathWithDriver, "GOPATH", tmp, "GOPACKAGESDRIVER", "off"}, + wantIDs: "[golist]", + }, { + desc: "driver_unset", + env: []string{"PATH", pathWithDriver, "GOPATH", "", "GOPACKAGESDRIVER", ""}, + wantIDs: "[gopackagesdriver]", + }, { + desc: "driver_set", + env: []string{"GOPACKAGESDRIVER", filepath.Join(tmp, "bin", "gopackagesdriver")}, + wantIDs: "[gopackagesdriver]", + }, + } { + t.Run(test.desc, func(t *testing.T) { + for i := 0; i < len(test.env); i += 2 { + key, value := test.env[i], test.env[i+1] + old, ok := os.LookupEnv(key) + if value == "" { + os.Unsetenv(key) + } else { + os.Setenv(key, value) + } + if ok { + defer os.Setenv(key, old) + } else { + defer os.Unsetenv(key) + } + } + + pkgs, err := packages.Load(nil, "golist") + if err != nil { + t.Fatal(err) + } + + gotIds := make([]string, len(pkgs)) + for i, pkg := range pkgs { + gotIds[i] = pkg.ID + } + if fmt.Sprint(pkgs) != test.wantIDs { + t.Errorf("got %v; want %v", gotIds, test.wantIDs) + } + }) + } +} + func errorMessages(errors []error) []string { var msgs []string for _, err := range errors {