diff --git a/py/test/collect.py b/py/test/collect.py index 6a76e3b66..31eb30970 100644 --- a/py/test/collect.py +++ b/py/test/collect.py @@ -127,6 +127,12 @@ class Node(object): def listnames(self): return [x.name for x in self.listchain()] + def getparent(self, cls): + current = self + while current and not isinstance(current, cls): + current = current.parent + return current + def _getitembynames(self, namelist): cur = self for name in namelist: diff --git a/py/test/funcargs.py b/py/test/funcargs.py index 84531ad34..2d022cfab 100644 --- a/py/test/funcargs.py +++ b/py/test/funcargs.py @@ -76,7 +76,7 @@ class FuncargRequest: self._pyfuncitem = pyfuncitem self.argname = argname self.function = pyfuncitem.obj - self.module = pyfuncitem._getparent(py.test.collect.Module).obj + self.module = pyfuncitem.getparent(py.test.collect.Module).obj self.cls = getattr(self.function, 'im_class', None) self.instance = getattr(self.function, 'im_self', None) self.config = pyfuncitem.config @@ -116,7 +116,7 @@ class FuncargRequest: if scope == "function": return self._pyfuncitem elif scope == "module": - return self._pyfuncitem._getparent(py.test.collect.Module) + return self._pyfuncitem.getparent(py.test.collect.Module) raise ValueError("unknown finalization scope %r" %(scope,)) def addfinalizer(self, finalizer, scope="function"): diff --git a/py/test/pycollect.py b/py/test/pycollect.py index 3a8774a09..420ef1dac 100644 --- a/py/test/pycollect.py +++ b/py/test/pycollect.py @@ -37,12 +37,6 @@ class PyobjMixin(object): def _getobj(self): return getattr(self.parent.obj, self.name) - def _getparent(self, cls): - current = self - while current and not isinstance(current, cls): - current = current.parent - return current - def getmodpath(self, stopatmodule=True, includemodule=False): """ return python path relative to the containing module. """ chain = self.listchain() @@ -146,10 +140,10 @@ class PyCollectorMixin(PyobjMixin, py.test.collect.Collector): return self._genfunctions(name, obj) def _genfunctions(self, name, funcobj): - module = self._getparent(Module).obj + module = self.getparent(Module).obj # due to _buildname2items funcobj is the raw function, we need # to work to get at the class - clscol = self._getparent(Class) + clscol = self.getparent(Class) cls = clscol and clscol.obj or None metafunc = funcargs.Metafunc(funcobj, config=self.config, cls=cls, module=module) gentesthook = self.config.hook.pytest_generate_tests.clone(extralookup=module) diff --git a/py/test/testing/test_collect.py b/py/test/testing/test_collect.py index 2e25662e6..0d1d3112f 100644 --- a/py/test/testing/test_collect.py +++ b/py/test/testing/test_collect.py @@ -37,6 +37,24 @@ class TestCollector: assert [1,2,3] != fn assert modcol != fn + def test_getparent(self, testdir): + modcol = testdir.getmodulecol(""" + class TestClass: + def test_foo(): + pass + """) + cls = modcol.collect_by_name("TestClass") + fn = cls.collect_by_name("()").collect_by_name("test_foo") + + parent = fn.getparent(py.test.collect.Module) + assert parent is modcol + + parent = fn.getparent(py.test.collect.Function) + assert parent is fn + + parent = fn.getparent(py.test.collect.Class) + assert parent is cls + def test_totrail_and_back(self, tmpdir): a = tmpdir.ensure("a", dir=1) tmpdir.ensure("a", "__init__.py") diff --git a/py/test/testing/test_pycollect.py b/py/test/testing/test_pycollect.py index b662cabba..a4fe30538 100644 --- a/py/test/testing/test_pycollect.py +++ b/py/test/testing/test_pycollect.py @@ -217,7 +217,7 @@ class TestGenerator: class TestFunction: def test_getmodulecollector(self, testdir): item = testdir.getitem("def test_func(): pass") - modcol = item._getparent(py.test.collect.Module) + modcol = item.getparent(py.test.collect.Module) assert isinstance(modcol, py.test.collect.Module) assert hasattr(modcol.obj, 'test_func')