Support for tests created with functools.partial

Fix #811
This commit is contained in:
Bruno Oliveira
2015-07-02 23:13:59 -03:00
parent 330de0a93d
commit dcdc823dd2
2 changed files with 69 additions and 5 deletions

View File

@@ -18,10 +18,19 @@ callable = py.builtin.callable
# used to work around a python2 exception info leak
exc_clear = getattr(sys, 'exc_clear', lambda: None)
def getfslineno(obj):
# xxx let decorators etc specify a sane ordering
def get_real_func(obj):
"""gets the real function object of the (possibly) wrapped object by
functools.wraps or functools.partial.
"""
while hasattr(obj, "__wrapped__"):
obj = obj.__wrapped__
if isinstance(obj, py.std.functools.partial):
obj = obj.func
return obj
def getfslineno(obj):
# xxx let decorators etc specify a sane ordering
obj = get_real_func(obj)
if hasattr(obj, 'place_as'):
obj = obj.place_as
fslineno = py.code.getfslineno(obj)
@@ -594,7 +603,10 @@ class FunctionMixin(PyobjMixin):
def _prunetraceback(self, excinfo):
if hasattr(self, '_obj') and not self.config.option.fulltrace:
code = py.code.Code(self.obj)
if isinstance(self.obj, py.std.functools.partial):
code = py.code.Code(self.obj.func)
else:
code = py.code.Code(self.obj)
path, firstlineno = code.path, code.firstlineno
traceback = excinfo.traceback
ntraceback = traceback.cut(path=path, firstlineno=firstlineno)
@@ -1537,7 +1549,7 @@ class FixtureLookupError(LookupError):
for function in stack:
fspath, lineno = getfslineno(function)
try:
lines, _ = inspect.getsourcelines(function)
lines, _ = inspect.getsourcelines(get_real_func(function))
except IOError:
error_msg = "file %s, line %s: source code not available"
addline(error_msg % (fspath, lineno+1))
@@ -1937,7 +1949,15 @@ def getfuncargnames(function, startindex=None):
if realfunction != function:
startindex += num_mock_patch_args(function)
function = realfunction
argnames = inspect.getargs(py.code.getrawcode(function))[0]
if isinstance(function, py.std.functools.partial):
argnames = inspect.getargs(py.code.getrawcode(function.func))[0]
partial = function
argnames = argnames[len(partial.args):]
if partial.keywords:
for kw in partial.keywords:
argnames.remove(kw)
else:
argnames = inspect.getargs(py.code.getrawcode(function))[0]
defaults = getattr(function, 'func_defaults',
getattr(function, '__defaults__', None)) or ()
numdefaults = len(defaults)