simplify and unify FDCapture API and usage:

* FDCapture now takes care through the 'patchsys' option to
  also set sys.stdin/out/err - setfiles/unsetfiles methods removed -
  i doubt anybody uses this outside of py.test's own old usage.
* stdin also goes through FDCapture now.

--HG--
branch : trunk
This commit is contained in:
holger krekel 2010-05-18 20:03:44 +02:00
parent c790288a5f
commit 9f5e6f9761
2 changed files with 27 additions and 37 deletions

View File

@ -26,23 +26,26 @@ except ImportError:
raise TypeError("not a byte value: %r" %(data,)) raise TypeError("not a byte value: %r" %(data,))
StringIO.write(self, data) StringIO.write(self, data)
patchsysdict = {0: 'stdin', 1: 'stdout', 2: 'stderr'}
class FDCapture: class FDCapture:
""" Capture IO to/from a given os-level filedescriptor. """ """ Capture IO to/from a given os-level filedescriptor. """
def __init__(self, targetfd, tmpfile=None, now=True): def __init__(self, targetfd, tmpfile=None, now=True, patchsys=False):
""" save targetfd descriptor, and open a new """ save targetfd descriptor, and open a new
temporary file there. If no tmpfile is temporary file there. If no tmpfile is
specified a tempfile.Tempfile() will be opened specified a tempfile.Tempfile() will be opened
in text mode. in text mode.
""" """
self.targetfd = targetfd self.targetfd = targetfd
self._patched = []
if tmpfile is None: if tmpfile is None:
f = tempfile.TemporaryFile('wb+') f = tempfile.TemporaryFile('wb+')
tmpfile = dupfile(f, encoding="UTF-8") tmpfile = dupfile(f, encoding="UTF-8")
f.close() f.close()
self.tmpfile = tmpfile self.tmpfile = tmpfile
self._savefd = os.dup(self.targetfd) self._savefd = os.dup(self.targetfd)
if patchsys:
self._oldsys = getattr(sys, patchsysdict[targetfd])
if now: if now:
self.start() self.start()
@ -53,20 +56,8 @@ class FDCapture:
raise ValueError("saved filedescriptor not valid, " raise ValueError("saved filedescriptor not valid, "
"did you call start() twice?") "did you call start() twice?")
os.dup2(self.tmpfile.fileno(), self.targetfd) os.dup2(self.tmpfile.fileno(), self.targetfd)
if hasattr(self, '_oldsys'):
def setasfile(self, name, module=sys): setattr(sys, patchsysdict[self.targetfd], self.tmpfile)
""" patch <module>.<name> to self.tmpfile
"""
key = (module, name)
self._patched.append((key, getattr(module, name)))
setattr(module, name, self.tmpfile)
def unsetfiles(self):
""" unpatch all patched items
"""
while self._patched:
(module, name), value = self._patched.pop()
setattr(module, name, value)
def done(self): def done(self):
""" unpatch and clean up, returns the self.tmpfile (file object) """ unpatch and clean up, returns the self.tmpfile (file object)
@ -74,7 +65,8 @@ class FDCapture:
os.dup2(self._savefd, self.targetfd) os.dup2(self._savefd, self.targetfd)
os.close(self._savefd) os.close(self._savefd)
self.tmpfile.seek(0) self.tmpfile.seek(0)
self.unsetfiles() if hasattr(self, '_oldsys'):
setattr(sys, patchsysdict[self.targetfd], self._oldsys)
return self.tmpfile return self.tmpfile
def writeorg(self, data): def writeorg(self, data):
@ -182,7 +174,6 @@ class StdCaptureFD(Capture):
in_=True, patchsys=True, now=True): in_=True, patchsys=True, now=True):
self._options = locals() self._options = locals()
self._save() self._save()
self.patchsys = patchsys
if now: if now:
self.startall() self.startall()
@ -191,10 +182,17 @@ class StdCaptureFD(Capture):
out = self._options['out'] out = self._options['out']
err = self._options['err'] err = self._options['err']
mixed = self._options['mixed'] mixed = self._options['mixed']
self.in_ = in_ patchsys = self._options['patchsys']
if in_: if in_:
if hasattr(in_, 'read'):
tmpfile = in_
else:
fd = os.open(devnullpath, os.O_RDONLY)
tmpfile = os.fdopen(fd)
try: try:
self._oldin = (sys.stdin, os.dup(0)) self.in_ = FDCapture(0, tmpfile=tmpfile, now=False,
patchsys=patchsys)
self._options['in_'] = self.in_.tmpfile
except OSError: except OSError:
pass pass
if out: if out:
@ -202,7 +200,8 @@ class StdCaptureFD(Capture):
if hasattr(out, 'write'): if hasattr(out, 'write'):
tmpfile = out tmpfile = out
try: try:
self.out = FDCapture(1, tmpfile=tmpfile, now=False) self.out = FDCapture(1, tmpfile=tmpfile,
now=False, patchsys=patchsys)
self._options['out'] = self.out.tmpfile self._options['out'] = self.out.tmpfile
except OSError: except OSError:
pass pass
@ -214,28 +213,23 @@ class StdCaptureFD(Capture):
else: else:
tmpfile = None tmpfile = None
try: try:
self.err = FDCapture(2, tmpfile=tmpfile, now=False) self.err = FDCapture(2, tmpfile=tmpfile,
now=False, patchsys=patchsys)
self._options['err'] = self.err.tmpfile self._options['err'] = self.err.tmpfile
except OSError: except OSError:
pass pass
def startall(self): def startall(self):
if self.in_: if hasattr(self, 'in_'):
self.in_.start()
sys.stdin = DontReadFromInput() sys.stdin = DontReadFromInput()
fd = os.open(devnullpath, os.O_RDONLY)
os.dup2(fd, 0)
os.close(fd)
out = getattr(self, 'out', None) out = getattr(self, 'out', None)
if out: if out:
out.start() out.start()
if self.patchsys:
out.setasfile('stdout')
err = getattr(self, 'err', None) err = getattr(self, 'err', None)
if err: if err:
err.start() err.start()
if self.patchsys:
err.setasfile('stderr')
def resume(self): def resume(self):
""" resume capturing with original temp files. """ """ resume capturing with original temp files. """
@ -248,11 +242,8 @@ class StdCaptureFD(Capture):
outfile = self.out.done() outfile = self.out.done()
if hasattr(self, 'err') and not self.err.tmpfile.closed: if hasattr(self, 'err') and not self.err.tmpfile.closed:
errfile = self.err.done() errfile = self.err.done()
if hasattr(self, '_oldin'): if hasattr(self, 'in_'):
oldsys, oldfd = self._oldin tmpfile = self.in_.done()
os.dup2(oldfd, 0)
os.close(oldfd)
sys.stdin = oldsys
self._save() self._save()
return outfile, errfile return outfile, errfile

View File

@ -144,8 +144,7 @@ class TestFDCapture:
f.close() f.close()
def test_stderr(self): def test_stderr(self):
cap = py.io.FDCapture(2) cap = py.io.FDCapture(2, patchsys=True)
cap.setasfile('stderr')
print_("hello", file=sys.stderr) print_("hello", file=sys.stderr)
f = cap.done() f = cap.done()
s = f.read() s = f.read()