diff --git a/go/loader/loader.go b/go/loader/loader.go index 16ccd246..c49ea6d6 100644 --- a/go/loader/loader.go +++ b/go/loader/loader.go @@ -268,6 +268,11 @@ type Config struct { // to startup, or by setting Build.CgoEnabled=false. Build *build.Context + // The current directory, used for resolving relative package + // references such as "./go/loader". If empty, os.Getwd will be + // used instead. + Cwd string + // If DisplayPath is non-nil, it is used to transform each // file name obtained from Build.Import(). This can be used // to prevent a virtualized build.Config's file names from @@ -640,8 +645,24 @@ func (conf *Config) Load() (*Program, error) { conf.TypeChecker.Error = func(e error) { fmt.Fprintln(os.Stderr, e) } } + // Set default working directory for relative package references. + if conf.Cwd == "" { + var err error + conf.Cwd, err = os.Getwd() + if err != nil { + return nil, err + } + } + + // Install default FindPackage hook using go/build logic. if conf.FindPackage == nil { - conf.FindPackage = defaultFindPackage + conf.FindPackage = func(ctxt *build.Context, path string) (*build.Package, error) { + bp, err := ctxt.Import(path, conf.Cwd, 0) + if _, ok := err.(*build.NoGoError); ok { + return bp, nil // empty directory is not an error + } + return bp, err + } } prog := &Program{ @@ -843,17 +864,6 @@ func (conf *Config) build() *build.Context { return &build.Default } -// defaultFindPackage locates the specified (possibly empty) package -// using go/build logic. It returns an error if not found. -func defaultFindPackage(ctxt *build.Context, path string) (*build.Package, error) { - // Import(srcDir="") disables local imports, e.g. import "./foo". - bp, err := ctxt.Import(path, "", 0) - if _, ok := err.(*build.NoGoError); ok { - return bp, nil // empty directory is not an error - } - return bp, err -} - // parsePackageFiles enumerates the files belonging to package path, // then loads, parses and returns them, plus a list of I/O or parse // errors that were encountered. @@ -1084,7 +1094,7 @@ func (imp *importer) loadFromSource(path string) (*PackageInfo, error) { if err != nil { return nil, err // package not found } - info := imp.newPackageInfo(path) + info := imp.newPackageInfo(bp.ImportPath) info.Importable = true files, errs := imp.conf.parsePackageFiles(bp, 'g') for _, err := range errs { diff --git a/go/loader/loader_test.go b/go/loader/loader_test.go index aa8c15be..0b42e398 100644 --- a/go/loader/loader_test.go +++ b/go/loader/loader_test.go @@ -356,6 +356,38 @@ func TestLoad_BadDependency_AllowErrors(t *testing.T) { } } +func TestCwd(t *testing.T) { + ctxt := fakeContext(map[string]string{"one/two/three": `package three`}) + for _, test := range []struct { + cwd, arg, want string + }{ + {cwd: "/go/src/one", arg: "./two/three", want: "one/two/three"}, + {cwd: "/go/src/one", arg: "../one/two/three", want: "one/two/three"}, + {cwd: "/go/src/one", arg: "one/two/three", want: "one/two/three"}, + {cwd: "/go/src/one/two/three", arg: ".", want: "one/two/three"}, + {cwd: "/go/src/one", arg: "two/three", want: ""}, + } { + conf := loader.Config{ + Cwd: test.cwd, + Build: ctxt, + } + conf.Import(test.arg) + + var got string + prog, err := conf.Load() + if prog != nil { + got = imported(prog) + } + if got != test.want { + t.Errorf("Load(%s) from %s: Imported = %s, want %s", + test.arg, test.cwd, got, test.want) + if err != nil { + t.Errorf("Load failed: %v", err) + } + } + } +} + // TODO(adonovan): more Load tests: // // failures: