extend Metafunc and write a pytest_generate_tests hook on the funcarg manager

which discovers factories
This commit is contained in:
holger krekel
2012-07-20 14:16:46 +02:00
parent e14459d45c
commit f358fe7154
3 changed files with 40 additions and 10 deletions

View File

@@ -443,6 +443,19 @@ class FuncargManager:
for plugin in plugins:
self.pytest_plugin_registered(plugin)
def pytest_generate_tests(self, metafunc):
for argname in metafunc.funcargnames:
faclist = self.getfactorylist(argname, metafunc.parentid,
metafunc.function, raising=False)
if faclist is None:
continue # will raise at setup time
for fac in faclist:
marker = getattr(fac, "funcarg", None)
if marker is not None:
params = marker.kwargs.get("params")
if params is not None:
metafunc.parametrize(argname, params, indirect=True)
def _parsefactories(self, holderobj, nodeid):
if holderobj in self._holderobjseen:
return
@@ -456,12 +469,14 @@ class FuncargManager:
obj = getattr(holderobj, name)
faclist.append((nodeid, obj))
def getfactorylist(self, argname, nodeid, function):
def getfactorylist(self, argname, nodeid, function, raising=True):
try:
factorydef = self.arg2facspec[argname]
except KeyError:
self._raiselookupfailed(argname, function, nodeid)
return self._matchfactories(factorydef, nodeid)
if raising:
self._raiselookupfailed(argname, function, nodeid)
else:
return self._matchfactories(factorydef, nodeid)
def _matchfactories(self, factorydef, nodeid):
l = []

View File

@@ -256,7 +256,7 @@ class PyCollector(PyobjMixin, pytest.Collector):
clscol = self.getparent(Class)
cls = clscol and clscol.obj or None
transfer_markers(funcobj, cls, module)
metafunc = Metafunc(funcobj, config=self.config,
metafunc = Metafunc(funcobj, parentid=self.nodeid, config=self.config,
cls=cls, module=module)
gentesthook = self.config.hook.pytest_generate_tests
extra = [module]
@@ -555,10 +555,12 @@ class CallSpec2(object):
class Metafunc:
def __init__(self, function, config=None, cls=None, module=None):
def __init__(self, function, config=None, cls=None, module=None,
parentid=""):
self.config = config
self.module = module
self.function = function
self.parentid = parentid
self.funcargnames = getfuncargnames(function,
startindex=int(cls is not None))
self.cls = cls
@@ -885,11 +887,6 @@ class FuncargRequest:
self.funcargnames = getfuncargnames(self.function)
self.parentid = pyfuncitem.parent.nodeid
def _discoverfactories(self):
for argname in self.funcargnames:
if argname not in self._funcargs:
self._getfaclist(argname)
def _getfaclist(self, argname):
faclist = self._name2factory.get(argname, None)
if faclist is None: