From f7747f5dd652694ce1357ea633787cb772627bef Mon Sep 17 00:00:00 2001 From: Ran Benita Date: Wed, 10 Jul 2019 13:27:28 +0300 Subject: [PATCH 01/10] Remove references to old-style classes in a couple error messages These don't exist in Python 3. --- src/_pytest/python_api.py | 5 +---- src/_pytest/recwarn.py | 2 +- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/src/_pytest/python_api.py b/src/_pytest/python_api.py index 82305bb1c..7c63a3588 100644 --- a/src/_pytest/python_api.py +++ b/src/_pytest/python_api.py @@ -647,10 +647,7 @@ def raises(expected_exception, *args, match=None, **kwargs): for exc in filterfalse( inspect.isclass, always_iterable(expected_exception, BASE_TYPE) ): - msg = ( - "exceptions must be old-style classes or" - " derived from BaseException, not %s" - ) + msg = "exceptions must be derived from BaseException, not %s" raise TypeError(msg % type(exc)) message = "DID NOT RAISE {}".format(expected_exception) diff --git a/src/_pytest/recwarn.py b/src/_pytest/recwarn.py index 3ab83d1e3..b124c69d5 100644 --- a/src/_pytest/recwarn.py +++ b/src/_pytest/recwarn.py @@ -151,7 +151,7 @@ class WarningsChecker(WarningsRecorder): def __init__(self, expected_warning=None, match_expr=None): super().__init__() - msg = "exceptions must be old-style classes or derived from Warning, not %s" + msg = "exceptions must be derived from Warning, not %s" if isinstance(expected_warning, tuple): for exc in expected_warning: if not inspect.isclass(exc): From 35a57a0dfbe2dd25f82c66e6c3388fecbac4eb72 Mon Sep 17 00:00:00 2001 From: Ran Benita Date: Sun, 14 Jul 2019 11:12:47 +0300 Subject: [PATCH 02/10] Use flake8's extend-ignore instead of ignore extend-ignore adds ignores in addition to flake8's existing ignores. The default ignores currently are: E121,E123,E126,E226,E24,E704,W503,W504 --- tox.ini | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tox.ini b/tox.ini index 832b97298..87d8141a3 100644 --- a/tox.ini +++ b/tox.ini @@ -156,7 +156,7 @@ markers = [flake8] max-line-length = 120 -ignore = E203,W503 +extend-ignore = E203 [isort] ; This config mimics what reorder-python-imports does. From 866904ab80605351c97c922db76c0586924403dd Mon Sep 17 00:00:00 2001 From: Ran Benita Date: Sun, 14 Jul 2019 11:39:30 +0300 Subject: [PATCH 03/10] Revert "Let context-managers for raises and warns handle unknown keyword arguments" This reverts commit dfe54cd82f55f17f3c9e6e078325f306a046b93b. The idea in the commit was to simplify the code by removing the check and instead letting it TypeError which has the same effect. However this type error is caught by mypy, and rather than ignoring the error we think it's better and clearer to go back to the previous explicit check. --- src/_pytest/python_api.py | 9 ++++++--- src/_pytest/recwarn.py | 7 ++++++- testing/python/raises.py | 6 ++++++ testing/test_recwarn.py | 6 ++++++ 4 files changed, 24 insertions(+), 4 deletions(-) diff --git a/src/_pytest/python_api.py b/src/_pytest/python_api.py index 7c63a3588..aae5ced33 100644 --- a/src/_pytest/python_api.py +++ b/src/_pytest/python_api.py @@ -653,9 +653,12 @@ def raises(expected_exception, *args, match=None, **kwargs): message = "DID NOT RAISE {}".format(expected_exception) if not args: - return RaisesContext( - expected_exception, message=message, match_expr=match, **kwargs - ) + if kwargs: + msg = "Unexpected keyword arguments passed to pytest.raises: " + msg += ", ".join(sorted(kwargs)) + msg += "\nUse context-manager form instead?" + raise TypeError(msg) + return RaisesContext(expected_exception, message, match) else: func = args[0] if not callable(func): diff --git a/src/_pytest/recwarn.py b/src/_pytest/recwarn.py index b124c69d5..7e772aa35 100644 --- a/src/_pytest/recwarn.py +++ b/src/_pytest/recwarn.py @@ -76,7 +76,12 @@ def warns(expected_warning, *args, match=None, **kwargs): """ __tracebackhide__ = True if not args: - return WarningsChecker(expected_warning, match_expr=match, **kwargs) + if kwargs: + msg = "Unexpected keyword arguments passed to pytest.warns: " + msg += ", ".join(sorted(kwargs)) + msg += "\nUse context-manager form instead?" + raise TypeError(msg) + return WarningsChecker(expected_warning, match_expr=match) else: func = args[0] if not callable(func): diff --git a/testing/python/raises.py b/testing/python/raises.py index 1f5594c8a..668be57fc 100644 --- a/testing/python/raises.py +++ b/testing/python/raises.py @@ -248,3 +248,9 @@ class TestRaises: with pytest.raises(CrappyClass()): pass assert "via __class__" in excinfo.value.args[0] + + def test_raises_context_manager_with_kwargs(self): + with pytest.raises(TypeError) as excinfo: + with pytest.raises(Exception, foo="bar"): + pass + assert "Unexpected keyword arguments" in str(excinfo.value) diff --git a/testing/test_recwarn.py b/testing/test_recwarn.py index 65fdd1682..208dc5b44 100644 --- a/testing/test_recwarn.py +++ b/testing/test_recwarn.py @@ -374,3 +374,9 @@ class TestWarns: assert f() == 10 assert pytest.warns(UserWarning, f) == 10 assert pytest.warns(UserWarning, f) == 10 + + def test_warns_context_manager_with_kwargs(self): + with pytest.raises(TypeError) as excinfo: + with pytest.warns(UserWarning, foo="bar"): + pass + assert "Unexpected keyword arguments" in str(excinfo.value) From d7ee3dac2ce7ac8c335e95fd643b05ade9bec897 Mon Sep 17 00:00:00 2001 From: Ran Benita Date: Wed, 10 Jul 2019 11:23:02 +0300 Subject: [PATCH 04/10] Type-annotate pytest.{exit,skip,fail,xfail,importorskip} --- src/_pytest/outcomes.py | 32 +++++++++++++++++++++++--------- 1 file changed, 23 insertions(+), 9 deletions(-) diff --git a/src/_pytest/outcomes.py b/src/_pytest/outcomes.py index c7e26f5cc..aaf0b35fb 100644 --- a/src/_pytest/outcomes.py +++ b/src/_pytest/outcomes.py @@ -3,21 +3,26 @@ exception classes and constants handling test outcomes as well as functions creating them """ import sys +from typing import Any +from typing import Optional from packaging.version import Version +if False: # TYPE_CHECKING + from typing import NoReturn + class OutcomeException(BaseException): """ OutcomeException and its subclass instances indicate and contain info about test and collection outcomes. """ - def __init__(self, msg=None, pytrace=True): + def __init__(self, msg: Optional[str] = None, pytrace: bool = True) -> None: BaseException.__init__(self, msg) self.msg = msg self.pytrace = pytrace - def __repr__(self): + def __repr__(self) -> str: if self.msg: val = self.msg if isinstance(val, bytes): @@ -36,7 +41,12 @@ class Skipped(OutcomeException): # in order to have Skipped exception printing shorter/nicer __module__ = "builtins" - def __init__(self, msg=None, pytrace=True, allow_module_level=False): + def __init__( + self, + msg: Optional[str] = None, + pytrace: bool = True, + allow_module_level: bool = False, + ) -> None: OutcomeException.__init__(self, msg=msg, pytrace=pytrace) self.allow_module_level = allow_module_level @@ -50,7 +60,9 @@ class Failed(OutcomeException): class Exit(Exception): """ raised for immediate program exits (no tracebacks/summaries)""" - def __init__(self, msg="unknown reason", returncode=None): + def __init__( + self, msg: str = "unknown reason", returncode: Optional[int] = None + ) -> None: self.msg = msg self.returncode = returncode super().__init__(msg) @@ -59,7 +71,7 @@ class Exit(Exception): # exposed helper methods -def exit(msg, returncode=None): +def exit(msg: str, returncode: Optional[int] = None) -> "NoReturn": """ Exit testing process. @@ -74,7 +86,7 @@ def exit(msg, returncode=None): exit.Exception = Exit # type: ignore -def skip(msg="", *, allow_module_level=False): +def skip(msg: str = "", *, allow_module_level: bool = False) -> "NoReturn": """ Skip an executing test with the given message. @@ -101,7 +113,7 @@ def skip(msg="", *, allow_module_level=False): skip.Exception = Skipped # type: ignore -def fail(msg="", pytrace=True): +def fail(msg: str = "", pytrace: bool = True) -> "NoReturn": """ Explicitly fail an executing test with the given message. @@ -121,7 +133,7 @@ class XFailed(Failed): """ raised from an explicit call to pytest.xfail() """ -def xfail(reason=""): +def xfail(reason: str = "") -> "NoReturn": """ Imperatively xfail an executing test or setup functions with the given reason. @@ -139,7 +151,9 @@ def xfail(reason=""): xfail.Exception = XFailed # type: ignore -def importorskip(modname, minversion=None, reason=None): +def importorskip( + modname: str, minversion: Optional[str] = None, reason: Optional[str] = None +) -> Any: """Imports and returns the requested module ``modname``, or skip the current test if the module cannot be imported. From 2dca68b863441ebe7e2ce16dcad9aaf6201a8fe7 Mon Sep 17 00:00:00 2001 From: Ran Benita Date: Wed, 10 Jul 2019 14:36:07 +0300 Subject: [PATCH 05/10] Type-annotate pytest.warns --- src/_pytest/recwarn.py | 110 ++++++++++++++++++++++++++++++++--------- 1 file changed, 87 insertions(+), 23 deletions(-) diff --git a/src/_pytest/recwarn.py b/src/_pytest/recwarn.py index 7e772aa35..19e3938c3 100644 --- a/src/_pytest/recwarn.py +++ b/src/_pytest/recwarn.py @@ -1,11 +1,23 @@ """ recording warnings during test function execution. """ -import inspect import re import warnings +from types import TracebackType +from typing import Any +from typing import Callable +from typing import Iterator +from typing import List +from typing import Optional +from typing import overload +from typing import Pattern +from typing import Tuple +from typing import Union from _pytest.fixtures import yield_fixture from _pytest.outcomes import fail +if False: # TYPE_CHECKING + from typing import Type + @yield_fixture def recwarn(): @@ -42,7 +54,32 @@ def deprecated_call(func=None, *args, **kwargs): return warns((DeprecationWarning, PendingDeprecationWarning), *args, **kwargs) -def warns(expected_warning, *args, match=None, **kwargs): +@overload +def warns( + expected_warning: Union["Type[Warning]", Tuple["Type[Warning]", ...]], + *, + match: Optional[Union[str, Pattern]] = ... +) -> "WarningsChecker": + ... # pragma: no cover + + +@overload +def warns( + expected_warning: Union["Type[Warning]", Tuple["Type[Warning]", ...]], + func: Callable, + *args: Any, + match: Optional[Union[str, Pattern]] = ..., + **kwargs: Any +) -> Union[Any]: + ... # pragma: no cover + + +def warns( + expected_warning: Union["Type[Warning]", Tuple["Type[Warning]", ...]], + *args: Any, + match: Optional[Union[str, Pattern]] = None, + **kwargs: Any +) -> Union["WarningsChecker", Any]: r"""Assert that code raises a particular class of warning. Specifically, the parameter ``expected_warning`` can be a warning class or @@ -101,26 +138,26 @@ class WarningsRecorder(warnings.catch_warnings): def __init__(self): super().__init__(record=True) self._entered = False - self._list = [] + self._list = [] # type: List[warnings._Record] @property - def list(self): + def list(self) -> List["warnings._Record"]: """The list of recorded warnings.""" return self._list - def __getitem__(self, i): + def __getitem__(self, i: int) -> "warnings._Record": """Get a recorded warning by index.""" return self._list[i] - def __iter__(self): + def __iter__(self) -> Iterator["warnings._Record"]: """Iterate through the recorded warnings.""" return iter(self._list) - def __len__(self): + def __len__(self) -> int: """The number of recorded warnings.""" return len(self._list) - def pop(self, cls=Warning): + def pop(self, cls: "Type[Warning]" = Warning) -> "warnings._Record": """Pop the first recorded warning, raise exception if not exists.""" for i, w in enumerate(self._list): if issubclass(w.category, cls): @@ -128,54 +165,80 @@ class WarningsRecorder(warnings.catch_warnings): __tracebackhide__ = True raise AssertionError("%r not found in warning list" % cls) - def clear(self): + def clear(self) -> None: """Clear the list of recorded warnings.""" self._list[:] = [] - def __enter__(self): + # Type ignored because it doesn't exactly warnings.catch_warnings.__enter__ + # -- it returns a List but we only emulate one. + def __enter__(self) -> "WarningsRecorder": # type: ignore if self._entered: __tracebackhide__ = True raise RuntimeError("Cannot enter %r twice" % self) - self._list = super().__enter__() + _list = super().__enter__() + # record=True means it's None. + assert _list is not None + self._list = _list warnings.simplefilter("always") return self - def __exit__(self, *exc_info): + def __exit__( + self, + exc_type: Optional["Type[BaseException]"], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> bool: if not self._entered: __tracebackhide__ = True raise RuntimeError("Cannot exit %r without entering first" % self) - super().__exit__(*exc_info) + super().__exit__(exc_type, exc_val, exc_tb) # Built-in catch_warnings does not reset entered state so we do it # manually here for this context manager to become reusable. self._entered = False + return False + class WarningsChecker(WarningsRecorder): - def __init__(self, expected_warning=None, match_expr=None): + def __init__( + self, + expected_warning: Optional[ + Union["Type[Warning]", Tuple["Type[Warning]", ...]] + ] = None, + match_expr: Optional[Union[str, Pattern]] = None, + ) -> None: super().__init__() msg = "exceptions must be derived from Warning, not %s" - if isinstance(expected_warning, tuple): + if expected_warning is None: + expected_warning_tup = None + elif isinstance(expected_warning, tuple): for exc in expected_warning: - if not inspect.isclass(exc): + if not issubclass(exc, Warning): raise TypeError(msg % type(exc)) - elif inspect.isclass(expected_warning): - expected_warning = (expected_warning,) - elif expected_warning is not None: + expected_warning_tup = expected_warning + elif issubclass(expected_warning, Warning): + expected_warning_tup = (expected_warning,) + else: raise TypeError(msg % type(expected_warning)) - self.expected_warning = expected_warning + self.expected_warning = expected_warning_tup self.match_expr = match_expr - def __exit__(self, *exc_info): - super().__exit__(*exc_info) + def __exit__( + self, + exc_type: Optional["Type[BaseException]"], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> bool: + super().__exit__(exc_type, exc_val, exc_tb) __tracebackhide__ = True # only check if we're not currently handling an exception - if all(a is None for a in exc_info): + if exc_type is None and exc_val is None and exc_tb is None: if self.expected_warning is not None: if not any(issubclass(r.category, self.expected_warning) for r in self): __tracebackhide__ = True @@ -200,3 +263,4 @@ class WarningsChecker(WarningsRecorder): [each.message for each in self], ) ) + return False From 55a570e5135cc8e08f242794b2b7a38677d81838 Mon Sep 17 00:00:00 2001 From: Ran Benita Date: Wed, 10 Jul 2019 12:30:29 +0300 Subject: [PATCH 06/10] Type-annotate ExceptionInfo --- src/_pytest/_code/code.py | 75 ++++++++++++++++++++++++++------------- 1 file changed, 50 insertions(+), 25 deletions(-) diff --git a/src/_pytest/_code/code.py b/src/_pytest/_code/code.py index d63c010e4..d9b06ffd9 100644 --- a/src/_pytest/_code/code.py +++ b/src/_pytest/_code/code.py @@ -5,6 +5,11 @@ import traceback from inspect import CO_VARARGS from inspect import CO_VARKEYWORDS from traceback import format_exception_only +from types import TracebackType +from typing import Optional +from typing import Pattern +from typing import Tuple +from typing import Union from weakref import ref import attr @@ -15,6 +20,9 @@ import _pytest from _pytest._io.saferepr import safeformat from _pytest._io.saferepr import saferepr +if False: # TYPE_CHECKING + from typing import Type + class Code: """ wrapper around Python code objects """ @@ -379,12 +387,14 @@ class ExceptionInfo: _assert_start_repr = "AssertionError('assert " - _excinfo = attr.ib() - _striptext = attr.ib(default="") - _traceback = attr.ib(default=None) + _excinfo = attr.ib( + type=Optional[Tuple["Type[BaseException]", BaseException, TracebackType]] + ) + _striptext = attr.ib(type=str, default="") + _traceback = attr.ib(type=Optional[Traceback], default=None) @classmethod - def from_current(cls, exprinfo=None): + def from_current(cls, exprinfo: Optional[str] = None) -> "ExceptionInfo": """returns an ExceptionInfo matching the current traceback .. warning:: @@ -396,8 +406,11 @@ class ExceptionInfo: strip ``AssertionError`` from the output, defaults to the exception message/``__str__()`` """ - tup = sys.exc_info() - assert tup[0] is not None, "no current exception" + tup_ = sys.exc_info() + assert tup_[0] is not None, "no current exception" + assert tup_[1] is not None, "no current exception" + assert tup_[2] is not None, "no current exception" + tup = (tup_[0], tup_[1], tup_[2]) _striptext = "" if exprinfo is None and isinstance(tup[1], AssertionError): exprinfo = getattr(tup[1], "msg", None) @@ -409,48 +422,60 @@ class ExceptionInfo: return cls(tup, _striptext) @classmethod - def for_later(cls): + def for_later(cls) -> "ExceptionInfo": """return an unfilled ExceptionInfo """ return cls(None) @property - def type(self): + def type(self) -> "Type[BaseException]": """the exception class""" + assert ( + self._excinfo is not None + ), ".type can only be used after the context manager exits" return self._excinfo[0] @property - def value(self): + def value(self) -> BaseException: """the exception value""" + assert ( + self._excinfo is not None + ), ".value can only be used after the context manager exits" return self._excinfo[1] @property - def tb(self): + def tb(self) -> TracebackType: """the exception raw traceback""" + assert ( + self._excinfo is not None + ), ".tb can only be used after the context manager exits" return self._excinfo[2] @property - def typename(self): + def typename(self) -> str: """the type name of the exception""" + assert ( + self._excinfo is not None + ), ".typename can only be used after the context manager exits" return self.type.__name__ @property - def traceback(self): + def traceback(self) -> Traceback: """the traceback""" if self._traceback is None: self._traceback = Traceback(self.tb, excinfo=ref(self)) return self._traceback @traceback.setter - def traceback(self, value): + def traceback(self, value: Traceback) -> None: self._traceback = value - def __repr__(self): + def __repr__(self) -> str: if self._excinfo is None: return "" return "" % (self.typename, len(self.traceback)) - def exconly(self, tryshort=False): + def exconly(self, tryshort: bool = False) -> str: """ return the exception as a string when 'tryshort' resolves to True, and the exception is a @@ -466,11 +491,11 @@ class ExceptionInfo: text = text[len(self._striptext) :] return text - def errisinstance(self, exc): + def errisinstance(self, exc: "Type[BaseException]") -> bool: """ return True if the exception is an instance of exc """ return isinstance(self.value, exc) - def _getreprcrash(self): + def _getreprcrash(self) -> "ReprFileLocation": exconly = self.exconly(tryshort=True) entry = self.traceback.getcrashentry() path, lineno = entry.frame.code.raw.co_filename, entry.lineno @@ -478,13 +503,13 @@ class ExceptionInfo: def getrepr( self, - showlocals=False, - style="long", - abspath=False, - tbfilter=True, - funcargs=False, - truncate_locals=True, - chain=True, + showlocals: bool = False, + style: str = "long", + abspath: bool = False, + tbfilter: bool = True, + funcargs: bool = False, + truncate_locals: bool = True, + chain: bool = True, ): """ Return str()able representation of this exception info. @@ -535,7 +560,7 @@ class ExceptionInfo: ) return fmt.repr_excinfo(self) - def match(self, regexp): + def match(self, regexp: Union[str, Pattern]) -> bool: """ Check whether the regular expression 'regexp' is found in the string representation of the exception using ``re.search``. If it matches From 56dcc9e1f884dc9f5f699c975a303cb0a97ccfa9 Mon Sep 17 00:00:00 2001 From: Ran Benita Date: Wed, 10 Jul 2019 11:28:43 +0300 Subject: [PATCH 07/10] Type-annotate pytest.raises --- src/_pytest/python_api.py | 63 ++++++++++++++++++++++++++++++++++----- 1 file changed, 56 insertions(+), 7 deletions(-) diff --git a/src/_pytest/python_api.py b/src/_pytest/python_api.py index aae5ced33..9ede24df6 100644 --- a/src/_pytest/python_api.py +++ b/src/_pytest/python_api.py @@ -7,6 +7,13 @@ from collections.abc import Sized from decimal import Decimal from itertools import filterfalse from numbers import Number +from types import TracebackType +from typing import Any +from typing import Callable +from typing import Optional +from typing import overload +from typing import Pattern +from typing import Tuple from typing import Union from more_itertools.more import always_iterable @@ -15,6 +22,9 @@ import _pytest._code from _pytest.compat import STRING_TYPES from _pytest.outcomes import fail +if False: # TYPE_CHECKING + from typing import Type # noqa: F401 (used in type string) + BASE_TYPE = (type, STRING_TYPES) @@ -528,7 +538,32 @@ def _is_numpy_array(obj): # builtin pytest.raises helper -def raises(expected_exception, *args, match=None, **kwargs): +@overload +def raises( + expected_exception: Union["Type[BaseException]", Tuple["Type[BaseException]", ...]], + *, + match: Optional[Union[str, Pattern]] = ... +) -> "RaisesContext": + ... # pragma: no cover + + +@overload +def raises( + expected_exception: Union["Type[BaseException]", Tuple["Type[BaseException]", ...]], + func: Callable, + *args: Any, + match: Optional[str] = ..., + **kwargs: Any +) -> Optional[_pytest._code.ExceptionInfo]: + ... # pragma: no cover + + +def raises( + expected_exception: Union["Type[BaseException]", Tuple["Type[BaseException]", ...]], + *args: Any, + match: Optional[Union[str, Pattern]] = None, + **kwargs: Any +) -> Union["RaisesContext", Optional[_pytest._code.ExceptionInfo]]: r""" Assert that a code block/function call raises ``expected_exception`` or raise a failure exception otherwise. @@ -676,21 +711,35 @@ raises.Exception = fail.Exception # type: ignore class RaisesContext: - def __init__(self, expected_exception, message, match_expr): + def __init__( + self, + expected_exception: Union[ + "Type[BaseException]", Tuple["Type[BaseException]", ...] + ], + message: str, + match_expr: Optional[Union[str, Pattern]] = None, + ) -> None: self.expected_exception = expected_exception self.message = message self.match_expr = match_expr - self.excinfo = None + self.excinfo = None # type: Optional[_pytest._code.ExceptionInfo] - def __enter__(self): + def __enter__(self) -> _pytest._code.ExceptionInfo: self.excinfo = _pytest._code.ExceptionInfo.for_later() return self.excinfo - def __exit__(self, *tp): + def __exit__( + self, + exc_type: Optional["Type[BaseException]"], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> bool: __tracebackhide__ = True - if tp[0] is None: + if exc_type is None: fail(self.message) - self.excinfo.__init__(tp) + assert self.excinfo is not None + # Type ignored because mypy doesn't like calling __init__ directly like this. + self.excinfo.__init__((exc_type, exc_val, exc_tb)) # type: ignore suppress_exception = issubclass(self.excinfo.type, self.expected_exception) if self.match_expr is not None and suppress_exception: self.excinfo.match(self.match_expr) From 14bf4cdf44be4d8e2482b1f2b9cafeba06c03550 Mon Sep 17 00:00:00 2001 From: Ran Benita Date: Wed, 10 Jul 2019 20:12:41 +0300 Subject: [PATCH 08/10] Make ExceptionInfo generic in the exception type This way, in with pytest.raises(ValueError) as cm: ... cm.value is a ValueError and not a BaseException. --- src/_pytest/_code/code.py | 21 +++++++++++++-------- src/_pytest/python_api.py | 33 ++++++++++++++++++++------------- 2 files changed, 33 insertions(+), 21 deletions(-) diff --git a/src/_pytest/_code/code.py b/src/_pytest/_code/code.py index d9b06ffd9..203e90287 100644 --- a/src/_pytest/_code/code.py +++ b/src/_pytest/_code/code.py @@ -6,9 +6,11 @@ from inspect import CO_VARARGS from inspect import CO_VARKEYWORDS from traceback import format_exception_only from types import TracebackType +from typing import Generic from typing import Optional from typing import Pattern from typing import Tuple +from typing import TypeVar from typing import Union from weakref import ref @@ -379,22 +381,25 @@ co_equal = compile( ) +_E = TypeVar("_E", bound=BaseException) + + @attr.s(repr=False) -class ExceptionInfo: +class ExceptionInfo(Generic[_E]): """ wraps sys.exc_info() objects and offers help for navigating the traceback. """ _assert_start_repr = "AssertionError('assert " - _excinfo = attr.ib( - type=Optional[Tuple["Type[BaseException]", BaseException, TracebackType]] - ) + _excinfo = attr.ib(type=Optional[Tuple["Type[_E]", "_E", TracebackType]]) _striptext = attr.ib(type=str, default="") _traceback = attr.ib(type=Optional[Traceback], default=None) @classmethod - def from_current(cls, exprinfo: Optional[str] = None) -> "ExceptionInfo": + def from_current( + cls, exprinfo: Optional[str] = None + ) -> "ExceptionInfo[BaseException]": """returns an ExceptionInfo matching the current traceback .. warning:: @@ -422,13 +427,13 @@ class ExceptionInfo: return cls(tup, _striptext) @classmethod - def for_later(cls) -> "ExceptionInfo": + def for_later(cls) -> "ExceptionInfo[_E]": """return an unfilled ExceptionInfo """ return cls(None) @property - def type(self) -> "Type[BaseException]": + def type(self) -> "Type[_E]": """the exception class""" assert ( self._excinfo is not None @@ -436,7 +441,7 @@ class ExceptionInfo: return self._excinfo[0] @property - def value(self) -> BaseException: + def value(self) -> _E: """the exception value""" assert ( self._excinfo is not None diff --git a/src/_pytest/python_api.py b/src/_pytest/python_api.py index 9ede24df6..7ca545878 100644 --- a/src/_pytest/python_api.py +++ b/src/_pytest/python_api.py @@ -10,10 +10,13 @@ from numbers import Number from types import TracebackType from typing import Any from typing import Callable +from typing import cast +from typing import Generic from typing import Optional from typing import overload from typing import Pattern from typing import Tuple +from typing import TypeVar from typing import Union from more_itertools.more import always_iterable @@ -537,33 +540,35 @@ def _is_numpy_array(obj): # builtin pytest.raises helper +_E = TypeVar("_E", bound=BaseException) + @overload def raises( - expected_exception: Union["Type[BaseException]", Tuple["Type[BaseException]", ...]], + expected_exception: Union["Type[_E]", Tuple["Type[_E]", ...]], *, match: Optional[Union[str, Pattern]] = ... -) -> "RaisesContext": +) -> "RaisesContext[_E]": ... # pragma: no cover @overload def raises( - expected_exception: Union["Type[BaseException]", Tuple["Type[BaseException]", ...]], + expected_exception: Union["Type[_E]", Tuple["Type[_E]", ...]], func: Callable, *args: Any, match: Optional[str] = ..., **kwargs: Any -) -> Optional[_pytest._code.ExceptionInfo]: +) -> Optional[_pytest._code.ExceptionInfo[_E]]: ... # pragma: no cover def raises( - expected_exception: Union["Type[BaseException]", Tuple["Type[BaseException]", ...]], + expected_exception: Union["Type[_E]", Tuple["Type[_E]", ...]], *args: Any, match: Optional[Union[str, Pattern]] = None, **kwargs: Any -) -> Union["RaisesContext", Optional[_pytest._code.ExceptionInfo]]: +) -> Union["RaisesContext[_E]", Optional[_pytest._code.ExceptionInfo[_E]]]: r""" Assert that a code block/function call raises ``expected_exception`` or raise a failure exception otherwise. @@ -703,28 +708,30 @@ def raises( try: func(*args[1:], **kwargs) except expected_exception: - return _pytest._code.ExceptionInfo.from_current() + # Cast to narrow the type to expected_exception (_E). + return cast( + _pytest._code.ExceptionInfo[_E], + _pytest._code.ExceptionInfo.from_current(), + ) fail(message) raises.Exception = fail.Exception # type: ignore -class RaisesContext: +class RaisesContext(Generic[_E]): def __init__( self, - expected_exception: Union[ - "Type[BaseException]", Tuple["Type[BaseException]", ...] - ], + expected_exception: Union["Type[_E]", Tuple["Type[_E]", ...]], message: str, match_expr: Optional[Union[str, Pattern]] = None, ) -> None: self.expected_exception = expected_exception self.message = message self.match_expr = match_expr - self.excinfo = None # type: Optional[_pytest._code.ExceptionInfo] + self.excinfo = None # type: Optional[_pytest._code.ExceptionInfo[_E]] - def __enter__(self) -> _pytest._code.ExceptionInfo: + def __enter__(self) -> _pytest._code.ExceptionInfo[_E]: self.excinfo = _pytest._code.ExceptionInfo.for_later() return self.excinfo From 3f1fb625844a38871d7bc357d290f39ce87039ca Mon Sep 17 00:00:00 2001 From: Ran Benita Date: Thu, 11 Jul 2019 14:36:13 +0300 Subject: [PATCH 09/10] Rework ExceptionInfo to not require manual __init__ call Mypy doesn't like calling __init__() in this way. --- src/_pytest/_code/code.py | 5 +++++ src/_pytest/python_api.py | 14 +++++++++----- 2 files changed, 14 insertions(+), 5 deletions(-) diff --git a/src/_pytest/_code/code.py b/src/_pytest/_code/code.py index 203e90287..07ed8066c 100644 --- a/src/_pytest/_code/code.py +++ b/src/_pytest/_code/code.py @@ -432,6 +432,11 @@ class ExceptionInfo(Generic[_E]): """ return cls(None) + def fill_unfilled(self, exc_info: Tuple["Type[_E]", _E, TracebackType]) -> None: + """fill an unfilled ExceptionInfo created with for_later()""" + assert self._excinfo is None, "ExceptionInfo was already filled" + self._excinfo = exc_info + @property def type(self) -> "Type[_E]": """the exception class""" diff --git a/src/_pytest/python_api.py b/src/_pytest/python_api.py index 7ca545878..e3cb8a970 100644 --- a/src/_pytest/python_api.py +++ b/src/_pytest/python_api.py @@ -745,9 +745,13 @@ class RaisesContext(Generic[_E]): if exc_type is None: fail(self.message) assert self.excinfo is not None - # Type ignored because mypy doesn't like calling __init__ directly like this. - self.excinfo.__init__((exc_type, exc_val, exc_tb)) # type: ignore - suppress_exception = issubclass(self.excinfo.type, self.expected_exception) - if self.match_expr is not None and suppress_exception: + if not issubclass(exc_type, self.expected_exception): + return False + # Cast to narrow the exception type now that it's verified. + exc_info = cast( + Tuple["Type[_E]", _E, TracebackType], (exc_type, exc_val, exc_tb) + ) + self.excinfo.fill_unfilled(exc_info) + if self.match_expr is not None: self.excinfo.match(self.match_expr) - return suppress_exception + return True From 11f1f792226622fc69a99b2a1b567630130b09f8 Mon Sep 17 00:00:00 2001 From: Ran Benita Date: Sun, 14 Jul 2019 11:36:33 +0300 Subject: [PATCH 10/10] Allow creating ExceptionInfo from existing exc_info for better typing This way the ExceptionInfo generic parameter can be inferred from the passed-in exc_info. See for example the replaced cast(). --- src/_pytest/_code/code.py | 47 +++++++++++++++++++++++++----------- src/_pytest/python_api.py | 10 ++++---- testing/code/test_excinfo.py | 10 +++++++- 3 files changed, 47 insertions(+), 20 deletions(-) diff --git a/src/_pytest/_code/code.py b/src/_pytest/_code/code.py index 07ed8066c..30ab01235 100644 --- a/src/_pytest/_code/code.py +++ b/src/_pytest/_code/code.py @@ -396,6 +396,33 @@ class ExceptionInfo(Generic[_E]): _striptext = attr.ib(type=str, default="") _traceback = attr.ib(type=Optional[Traceback], default=None) + @classmethod + def from_exc_info( + cls, + exc_info: Tuple["Type[_E]", "_E", TracebackType], + exprinfo: Optional[str] = None, + ) -> "ExceptionInfo[_E]": + """returns an ExceptionInfo for an existing exc_info tuple. + + .. warning:: + + Experimental API + + + :param exprinfo: a text string helping to determine if we should + strip ``AssertionError`` from the output, defaults + to the exception message/``__str__()`` + """ + _striptext = "" + if exprinfo is None and isinstance(exc_info[1], AssertionError): + exprinfo = getattr(exc_info[1], "msg", None) + if exprinfo is None: + exprinfo = saferepr(exc_info[1]) + if exprinfo and exprinfo.startswith(cls._assert_start_repr): + _striptext = "AssertionError: " + + return cls(exc_info, _striptext) + @classmethod def from_current( cls, exprinfo: Optional[str] = None @@ -411,20 +438,12 @@ class ExceptionInfo(Generic[_E]): strip ``AssertionError`` from the output, defaults to the exception message/``__str__()`` """ - tup_ = sys.exc_info() - assert tup_[0] is not None, "no current exception" - assert tup_[1] is not None, "no current exception" - assert tup_[2] is not None, "no current exception" - tup = (tup_[0], tup_[1], tup_[2]) - _striptext = "" - if exprinfo is None and isinstance(tup[1], AssertionError): - exprinfo = getattr(tup[1], "msg", None) - if exprinfo is None: - exprinfo = saferepr(tup[1]) - if exprinfo and exprinfo.startswith(cls._assert_start_repr): - _striptext = "AssertionError: " - - return cls(tup, _striptext) + tup = sys.exc_info() + assert tup[0] is not None, "no current exception" + assert tup[1] is not None, "no current exception" + assert tup[2] is not None, "no current exception" + exc_info = (tup[0], tup[1], tup[2]) + return cls.from_exc_info(exc_info) @classmethod def for_later(cls) -> "ExceptionInfo[_E]": diff --git a/src/_pytest/python_api.py b/src/_pytest/python_api.py index e3cb8a970..08426d69c 100644 --- a/src/_pytest/python_api.py +++ b/src/_pytest/python_api.py @@ -707,11 +707,11 @@ def raises( ) try: func(*args[1:], **kwargs) - except expected_exception: - # Cast to narrow the type to expected_exception (_E). - return cast( - _pytest._code.ExceptionInfo[_E], - _pytest._code.ExceptionInfo.from_current(), + except expected_exception as e: + # We just caught the exception - there is a traceback. + assert e.__traceback__ is not None + return _pytest._code.ExceptionInfo.from_exc_info( + (type(e), e, e.__traceback__) ) fail(message) diff --git a/testing/code/test_excinfo.py b/testing/code/test_excinfo.py index d7771833a..76f974957 100644 --- a/testing/code/test_excinfo.py +++ b/testing/code/test_excinfo.py @@ -58,7 +58,7 @@ class TWMock: fullwidth = 80 -def test_excinfo_simple(): +def test_excinfo_simple() -> None: try: raise ValueError except ValueError: @@ -66,6 +66,14 @@ def test_excinfo_simple(): assert info.type == ValueError +def test_excinfo_from_exc_info_simple(): + try: + raise ValueError + except ValueError as e: + info = _pytest._code.ExceptionInfo.from_exc_info((type(e), e, e.__traceback__)) + assert info.type == ValueError + + def test_excinfo_getstatement(): def g(): raise ValueError