diff --git a/imports/fix.go b/imports/fix.go index ab29ce9c..9cdbbe91 100644 --- a/imports/fix.go +++ b/imports/fix.go @@ -378,11 +378,6 @@ func (s byImportPathShortLength) Less(i, j int) bool { } func (s byImportPathShortLength) Swap(i, j int) { s[i], s[j] = s[j], s[i] } -var visitedSymlinks struct { - sync.Mutex - m map[string]struct{} -} - // guarded by populateIgnoreOnce; populates ignoredDirs. func populateIgnore() { for _, srcDir := range build.Default.SrcDirs() { @@ -459,23 +454,26 @@ func shouldTraverse(dir string, fi os.FileInfo) bool { if skipDir(ts) { return false } + // Check for symlink loops by statting each directory component + // and seeing if any are the same file as ts. + for { + parent := filepath.Dir(path) + if parent == path { + // Made it to the root without seeing a cycle. + // Use this symlink. + return true + } + parentInfo, err := os.Stat(parent) + if err != nil { + return false + } + if os.SameFile(ts, parentInfo) { + // Cycle. Don't traverse. + return false + } + path = parent + } - realParent, err := filepath.EvalSymlinks(dir) - if err != nil { - fmt.Fprint(os.Stderr, err) - return false - } - realPath := filepath.Join(realParent, fi.Name()) - visitedSymlinks.Lock() - defer visitedSymlinks.Unlock() - if visitedSymlinks.m == nil { - visitedSymlinks.m = make(map[string]struct{}) - } - if _, ok := visitedSymlinks.m[realPath]; ok { - return false - } - visitedSymlinks.m[realPath] = struct{}{} - return true } var testHookScanDir = func(dir string) {} diff --git a/imports/fix_test.go b/imports/fix_test.go index 4a77ef5b..048b9c3a 100644 --- a/imports/fix_test.go +++ b/imports/fix_test.go @@ -863,6 +863,15 @@ func TestImportSymlinks(t *testing.T) { } defer os.RemoveAll(newGoPath) + // Create: + // $GOPATH/target/ + // $GOPATH/target/f.go // package mypkg\nvar Foo = 123\n + // $GOPATH/src/x/ + // $GOPATH/src/x/mypkg => $GOPATH/target // symlink + // $GOPATH/src/x/apkg => $GOPATH/src/x // symlink loop + // Test: + // $GOPATH/src/myotherpkg/toformat.go referencing mypkg.Foo + targetPath := newGoPath + "/target" if err := os.MkdirAll(targetPath, 0755); err != nil { t.Fatal(err) @@ -1060,7 +1069,6 @@ func withEmptyGoPath(fn func()) { oldGOPATH := build.Default.GOPATH oldGOROOT := build.Default.GOROOT build.Default.GOPATH = "" - visitedSymlinks.m = nil testHookScanDir = func(string) {} testMu.Unlock() @@ -1673,3 +1681,74 @@ func TestPkgIsCandidate(t *testing.T) { } } } + +func TestShouldTraverse(t *testing.T) { + switch runtime.GOOS { + case "windows", "plan9": + t.Skipf("skipping symlink-requiring test on %s", runtime.GOOS) + } + + dir, err := ioutil.TempDir("", "goimports-") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(dir) + + // Note: mapToDir prepends "src" to each element, since + // mapToDir was made for creating GOPATHs. + if err := mapToDir(dir, map[string]string{ + "foo/foo2/file.txt": "", + "foo/foo2/link-to-src": "LINK:" + dir + "/src", + "foo/foo2/link-to-src-foo": "LINK:" + dir + "/src/foo", + "foo/foo2/link-to-dot": "LINK:.", + "bar/bar2/file.txt": "", + "bar/bar2/link-to-src-foo": "LINK:" + dir + "/src/foo", + + "a/b/c": "LINK:" + dir + "/src/a/d", + "a/d/e": "LINK:" + dir + "/src/a/b", + }); err != nil { + t.Fatal(err) + } + tests := []struct { + dir string + file string + want bool + }{ + { + dir: dir + "/src/foo/foo2", + file: "link-to-src-foo", + want: false, // loop + }, + { + dir: dir + "/src/foo/foo2", + file: "link-to-src", + want: false, // loop + }, + { + dir: dir + "/src/foo/foo2", + file: "link-to-dot", + want: false, // loop + }, + { + dir: dir + "/src/bar/bar2", + file: "link-to-src-foo", + want: true, // not a loop + }, + { + dir: dir + "/src/a/b/c", + file: "e", + want: false, // loop: "e" is the same as "b". + }, + } + for i, tt := range tests { + fi, err := os.Stat(filepath.Join(tt.dir, tt.file)) + if err != nil { + t.Errorf("%d. Stat = %v", i, err) + continue + } + got := shouldTraverse(tt.dir, fi) + if got != tt.want { + t.Errorf("%d. shouldTraverse(%q, %q) = %v; want %v", i, tt.dir, tt.file, got, tt.want) + } + } +}