diff --git a/_pytest/capture.py b/_pytest/capture.py index 778b1396b..b3b7dee7e 100644 --- a/_pytest/capture.py +++ b/_pytest/capture.py @@ -333,6 +333,8 @@ class MultiCapture(object): return (self.out.snap() if self.out is not None else "", self.err.snap() if self.err is not None else "") +class NoCapture: + __init__ = start = done = lambda *args: None class FDCapture: """ Capture IO to/from a given os-level filedescriptor. """ @@ -340,36 +342,38 @@ class FDCapture: def __init__(self, targetfd, tmpfile=None): self.targetfd = targetfd try: - self._savefd = os.dup(self.targetfd) + self.targetfd_save = os.dup(self.targetfd) except OSError: self.start = lambda: None self.done = lambda: None else: - if tmpfile is None: - if targetfd == 0: - tmpfile = open(os.devnull, "r") - else: + if targetfd == 0: + assert not tmpfile, "cannot set tmpfile with stdin" + tmpfile = open(os.devnull, "r") + self.syscapture = SysCapture(targetfd) + else: + if tmpfile is None: f = TemporaryFile() with f: tmpfile = safe_text_dupfile(f, mode="wb+") + if targetfd in patchsysdict: + self.syscapture = SysCapture(targetfd, tmpfile) + else: + self.syscapture = NoCapture() self.tmpfile = tmpfile - if targetfd in patchsysdict: - self._oldsys = getattr(sys, patchsysdict[targetfd]) + self.tmpfile_fd = tmpfile.fileno() def __repr__(self): - return "" % (self.targetfd, self._savefd) + return "" % (self.targetfd, self.targetfd_save) def start(self): """ Start capturing on targetfd using memorized tmpfile. """ try: - os.fstat(self._savefd) - except OSError: + os.fstat(self.targetfd_save) + except (AttributeError, OSError): raise ValueError("saved filedescriptor not valid anymore") - targetfd = self.targetfd - os.dup2(self.tmpfile.fileno(), targetfd) - if hasattr(self, '_oldsys'): - subst = self.tmpfile if targetfd != 0 else DontReadFromInput() - setattr(sys, patchsysdict[targetfd], subst) + os.dup2(self.tmpfile_fd, self.targetfd) + self.syscapture.start() def snap(self): f = self.tmpfile @@ -386,28 +390,38 @@ class FDCapture: def done(self): """ stop capturing, restore streams, return original capture file, seeked to position zero. """ - os.dup2(self._savefd, self.targetfd) - os.close(self._savefd) - if hasattr(self, '_oldsys'): - setattr(sys, patchsysdict[self.targetfd], self._oldsys) + targetfd_save = self.__dict__.pop("targetfd_save") + os.dup2(targetfd_save, self.targetfd) + os.close(targetfd_save) + self.syscapture.done() self.tmpfile.close() + def suspend(self): + self.syscapture.suspend() + os.dup2(self.targetfd_save, self.targetfd) + + def resume(self): + self.syscapture.resume() + os.dup2(self.tmpfile_fd, self.targetfd) + def writeorg(self, data): """ write to original file descriptor. """ if py.builtin._istext(data): data = data.encode("utf8") # XXX use encoding of original stream - os.write(self._savefd, data) + os.write(self.targetfd_save, data) class SysCapture: - def __init__(self, fd): + def __init__(self, fd, tmpfile=None): name = patchsysdict[fd] self._old = getattr(sys, name) self.name = name - if name == "stdin": - self.tmpfile = DontReadFromInput() - else: - self.tmpfile = TextIO() + if tmpfile is None: + if name == "stdin": + tmpfile = DontReadFromInput() + else: + tmpfile = TextIO() + self.tmpfile = tmpfile def start(self): setattr(sys, self.name, self.tmpfile) @@ -421,8 +435,15 @@ class SysCapture: def done(self): setattr(sys, self.name, self._old) + del self._old self.tmpfile.close() + def suspend(self): + setattr(sys, self.name, self._old) + + def resume(self): + setattr(sys, self.name, self.tmpfile) + def writeorg(self, data): self._old.write(data) self._old.flush() diff --git a/testing/test_capture.py b/testing/test_capture.py index 33f538890..fb2ff75d7 100644 --- a/testing/test_capture.py +++ b/testing/test_capture.py @@ -726,15 +726,11 @@ class TestFDCapture: assert s == "hello\n" def test_stdin(self, tmpfile): - tmpfile.write(tobytes("3")) - tmpfile.seek(0) - cap = capture.FDCapture(0, tmpfile) + cap = capture.FDCapture(0) cap.start() - # check with os.read() directly instead of raw_input(), because - # sys.stdin itself may be redirected (as pytest now does by default) x = os.read(0, 100).strip() cap.done() - assert x == tobytes("3") + assert x == tobytes('') def test_writeorg(self, tmpfile): data1, data2 = tobytes("foo"), tobytes("bar") @@ -751,7 +747,37 @@ class TestFDCapture: stmp = open(tmpfile.name, 'rb').read() assert stmp == data2 + def test_simple_resume_suspend(self, tmpfile): + with saved_fd(1): + cap = capture.FDCapture(1) + cap.start() + data = tobytes("hello") + os.write(1, data) + sys.stdout.write("whatever") + s = cap.snap() + assert s == "hellowhatever" + cap.suspend() + os.write(1, tobytes("world")) + sys.stdout.write("qlwkej") + assert not cap.snap() + cap.resume() + os.write(1, tobytes("but now")) + sys.stdout.write(" yes\n") + s = cap.snap() + assert s == "but now yes\n" + cap.suspend() + cap.done() + pytest.raises(AttributeError, cap.suspend) +@contextlib.contextmanager +def saved_fd(fd): + new_fd = os.dup(fd) + try: + yield + finally: + os.dup2(new_fd, fd) + + class TestStdCapture: captureclass = staticmethod(StdCapture)