diff --git a/_pytest/main.py b/_pytest/main.py index 5771a1699..3383400e1 100644 --- a/_pytest/main.py +++ b/_pytest/main.py @@ -455,10 +455,6 @@ class Collector(Node): return str(exc.args[0]) return self._repr_failure_py(excinfo, style="short") - def _memocollect(self): - """ internal helper method to cache results of calling collect(). """ - return self._memoizedcall('_collected', lambda: list(self.collect())) - def _prunetraceback(self, excinfo): if hasattr(self, 'fspath'): traceback = excinfo.traceback diff --git a/_pytest/pytester.py b/_pytest/pytester.py index 651160cc7..da8996a04 100644 --- a/_pytest/pytester.py +++ b/_pytest/pytester.py @@ -10,6 +10,8 @@ import time import traceback from fnmatch import fnmatch +from weakref import WeakKeyDictionary + from py.builtin import print_ from _pytest._code import Source @@ -401,6 +403,7 @@ class Testdir: def __init__(self, request, tmpdir_factory): self.request = request + self._mod_collections = WeakKeyDictionary() # XXX remove duplication with tmpdir plugin basetmp = tmpdir_factory.ensuretemp("testdir") name = request.function.__name__ @@ -856,6 +859,7 @@ class Testdir: self.makepyfile(__init__ = "#") self.config = config = self.parseconfigure(path, *configargs) node = self.getnode(config, path) + return node def collect_by_name(self, modcol, name): @@ -870,7 +874,9 @@ class Testdir: :param name: The name of the node to return. """ - for colitem in modcol._memocollect(): + if modcol not in self._mod_collections: + self._mod_collections[modcol] = list(modcol.collect()) + for colitem in self._mod_collections[modcol]: if colitem.name == name: return colitem diff --git a/_pytest/runner.py b/_pytest/runner.py index d1a155415..6f1759f14 100644 --- a/_pytest/runner.py +++ b/_pytest/runner.py @@ -330,7 +330,9 @@ class TeardownErrorReport(BaseReport): self.__dict__.update(extra) def pytest_make_collect_report(collector): - call = CallInfo(collector._memocollect, "memocollect") + call = CallInfo( + lambda: list(collector.collect()), + 'collect') longrepr = None if not call.excinfo: outcome = "passed" @@ -568,4 +570,3 @@ def importorskip(modname, minversion=None): raise Skipped("module %r has __version__ %r, required is: %r" %( modname, verattr, minversion), allow_module_level=True) return mod -