diff --git a/changelog/3457.feature.rst b/changelog/3457.feature.rst new file mode 100644 index 000000000..c30943070 --- /dev/null +++ b/changelog/3457.feature.rst @@ -0,0 +1,4 @@ +New `pytest_assertion_pass `__ +hook, called with context information when an assertion *passes*. + +This hook is still **experimental** so use it with caution. diff --git a/changelog/3457.trivial.rst b/changelog/3457.trivial.rst new file mode 100644 index 000000000..f18887634 --- /dev/null +++ b/changelog/3457.trivial.rst @@ -0,0 +1 @@ +pytest now also depends on the `astor `__ package. diff --git a/doc/en/reference.rst b/doc/en/reference.rst index 6750b17f0..5abb01f50 100644 --- a/doc/en/reference.rst +++ b/doc/en/reference.rst @@ -665,15 +665,14 @@ Session related reporting hooks: .. autofunction:: pytest_fixture_post_finalizer .. autofunction:: pytest_warning_captured -And here is the central hook for reporting about -test execution: +Central hook for reporting about test execution: .. autofunction:: pytest_runtest_logreport -You can also use this hook to customize assertion representation for some -types: +Assertion related hooks: .. autofunction:: pytest_assertrepr_compare +.. autofunction:: pytest_assertion_pass Debugging/Interaction hooks diff --git a/setup.py b/setup.py index 4c87c6429..7d9532816 100644 --- a/setup.py +++ b/setup.py @@ -13,6 +13,7 @@ INSTALL_REQUIRES = [ "pluggy>=0.12,<1.0", "importlib-metadata>=0.12", "wcwidth", + "astor", ] diff --git a/src/_pytest/assertion/__init__.py b/src/_pytest/assertion/__init__.py index e52101c9f..126929b6a 100644 --- a/src/_pytest/assertion/__init__.py +++ b/src/_pytest/assertion/__init__.py @@ -23,6 +23,13 @@ def pytest_addoption(parser): test modules on import to provide assert expression information.""", ) + parser.addini( + "enable_assertion_pass_hook", + type="bool", + default=False, + help="Enables the pytest_assertion_pass hook." + "Make sure to delete any previously generated pyc cache files.", + ) def register_assert_rewrite(*names): @@ -92,7 +99,7 @@ def pytest_collection(session): def pytest_runtest_setup(item): - """Setup the pytest_assertrepr_compare hook + """Setup the pytest_assertrepr_compare and pytest_assertion_pass hooks The newinterpret and rewrite modules will use util._reprcompare if it exists to use custom reporting via the @@ -129,9 +136,19 @@ def pytest_runtest_setup(item): util._reprcompare = callbinrepr + if item.ihook.pytest_assertion_pass.get_hookimpls(): + + def call_assertion_pass_hook(lineno, expl, orig): + item.ihook.pytest_assertion_pass( + item=item, lineno=lineno, orig=orig, expl=expl + ) + + util._assertion_pass = call_assertion_pass_hook + def pytest_runtest_teardown(item): util._reprcompare = None + util._assertion_pass = None def pytest_sessionfinish(session): diff --git a/src/_pytest/assertion/rewrite.py b/src/_pytest/assertion/rewrite.py index e117ae251..8810c156c 100644 --- a/src/_pytest/assertion/rewrite.py +++ b/src/_pytest/assertion/rewrite.py @@ -10,6 +10,7 @@ import struct import sys import types +import astor import atomicwrites from _pytest._io.saferepr import saferepr @@ -134,7 +135,7 @@ class AssertionRewritingHook: co = _read_pyc(fn, pyc, state.trace) if co is None: state.trace("rewriting {!r}".format(fn)) - source_stat, co = _rewrite_test(fn) + source_stat, co = _rewrite_test(fn, self.config) if write: self._writing_pyc = True try: @@ -278,13 +279,13 @@ def _write_pyc(state, co, source_stat, pyc): return True -def _rewrite_test(fn): +def _rewrite_test(fn, config): """read and rewrite *fn* and return the code object.""" stat = os.stat(fn) with open(fn, "rb") as f: source = f.read() tree = ast.parse(source, filename=fn) - rewrite_asserts(tree, fn) + rewrite_asserts(tree, fn, config) co = compile(tree, fn, "exec", dont_inherit=True) return stat, co @@ -326,9 +327,9 @@ def _read_pyc(source, pyc, trace=lambda x: None): return co -def rewrite_asserts(mod, module_path=None): +def rewrite_asserts(mod, module_path=None, config=None): """Rewrite the assert statements in mod.""" - AssertionRewriter(module_path).run(mod) + AssertionRewriter(module_path, config).run(mod) def _saferepr(obj): @@ -401,6 +402,17 @@ def _call_reprcompare(ops, results, expls, each_obj): return expl +def _call_assertion_pass(lineno, orig, expl): + if util._assertion_pass is not None: + util._assertion_pass(lineno=lineno, orig=orig, expl=expl) + + +def _check_if_assertion_pass_impl(): + """Checks if any plugins implement the pytest_assertion_pass hook + in order not to generate explanation unecessarily (might be expensive)""" + return True if util._assertion_pass else False + + unary_map = {ast.Not: "not %s", ast.Invert: "~%s", ast.USub: "-%s", ast.UAdd: "+%s"} binop_map = { @@ -473,7 +485,8 @@ class AssertionRewriter(ast.NodeVisitor): original assert statement: it rewrites the test of an assertion to provide intermediate values and replace it with an if statement which raises an assertion error with a detailed explanation in - case the expression is false. + case the expression is false and calls pytest_assertion_pass hook + if expression is true. For this .visit_Assert() uses the visitor pattern to visit all the AST nodes of the ast.Assert.test field, each visit call returning @@ -491,9 +504,10 @@ class AssertionRewriter(ast.NodeVisitor): by statements. Variables are created using .variable() and have the form of "@py_assert0". - :on_failure: The AST statements which will be executed if the - assertion test fails. This is the code which will construct - the failure message and raises the AssertionError. + :expl_stmts: The AST statements which will be executed to get + data from the assertion. This is the code which will construct + the detailed assertion message that is used in the AssertionError + or for the pytest_assertion_pass hook. :explanation_specifiers: A dict filled by .explanation_param() with %-formatting placeholders and their corresponding @@ -509,9 +523,16 @@ class AssertionRewriter(ast.NodeVisitor): """ - def __init__(self, module_path): + def __init__(self, module_path, config): super().__init__() self.module_path = module_path + self.config = config + if config is not None: + self.enable_assertion_pass_hook = config.getini( + "enable_assertion_pass_hook" + ) + else: + self.enable_assertion_pass_hook = False def run(self, mod): """Find all assert statements in *mod* and rewrite them.""" @@ -642,7 +663,7 @@ class AssertionRewriter(ast.NodeVisitor): The expl_expr should be an ast.Str instance constructed from the %-placeholders created by .explanation_param(). This will - add the required code to format said string to .on_failure and + add the required code to format said string to .expl_stmts and return the ast.Name instance of the formatted string. """ @@ -653,7 +674,9 @@ class AssertionRewriter(ast.NodeVisitor): format_dict = ast.Dict(keys, list(current.values())) form = ast.BinOp(expl_expr, ast.Mod(), format_dict) name = "@py_format" + str(next(self.variable_counter)) - self.on_failure.append(ast.Assign([ast.Name(name, ast.Store())], form)) + if self.enable_assertion_pass_hook: + self.format_variables.append(name) + self.expl_stmts.append(ast.Assign([ast.Name(name, ast.Store())], form)) return ast.Name(name, ast.Load()) def generic_visit(self, node): @@ -687,8 +710,12 @@ class AssertionRewriter(ast.NodeVisitor): self.statements = [] self.variables = [] self.variable_counter = itertools.count() + + if self.enable_assertion_pass_hook: + self.format_variables = [] + self.stack = [] - self.on_failure = [] + self.expl_stmts = [] self.push_format_context() # Rewrite assert into a bunch of statements. top_condition, explanation = self.visit(assert_.test) @@ -699,24 +726,77 @@ class AssertionRewriter(ast.NodeVisitor): top_condition, module_path=self.module_path, lineno=assert_.lineno ) ) - # Create failure message. - body = self.on_failure - negation = ast.UnaryOp(ast.Not(), top_condition) - self.statements.append(ast.If(negation, body, [])) - if assert_.msg: - assertmsg = self.helper("_format_assertmsg", assert_.msg) - explanation = "\n>assert " + explanation - else: - assertmsg = ast.Str("") - explanation = "assert " + explanation - template = ast.BinOp(assertmsg, ast.Add(), ast.Str(explanation)) - msg = self.pop_format_context(template) - fmt = self.helper("_format_explanation", msg) - err_name = ast.Name("AssertionError", ast.Load()) - exc = ast.Call(err_name, [fmt], []) - raise_ = ast.Raise(exc, None) - body.append(raise_) + if self.enable_assertion_pass_hook: # Experimental pytest_assertion_pass hook + negation = ast.UnaryOp(ast.Not(), top_condition) + msg = self.pop_format_context(ast.Str(explanation)) + + # Failed + if assert_.msg: + assertmsg = self.helper("_format_assertmsg", assert_.msg) + gluestr = "\n>assert " + else: + assertmsg = ast.Str("") + gluestr = "assert " + err_explanation = ast.BinOp(ast.Str(gluestr), ast.Add(), msg) + err_msg = ast.BinOp(assertmsg, ast.Add(), err_explanation) + err_name = ast.Name("AssertionError", ast.Load()) + fmt = self.helper("_format_explanation", err_msg) + exc = ast.Call(err_name, [fmt], []) + raise_ = ast.Raise(exc, None) + statements_fail = [] + statements_fail.extend(self.expl_stmts) + statements_fail.append(raise_) + + # Passed + fmt_pass = self.helper("_format_explanation", msg) + orig = astor.to_source(assert_.test).rstrip("\n").lstrip("(").rstrip(")") + hook_call_pass = ast.Expr( + self.helper( + "_call_assertion_pass", + ast.Num(assert_.lineno), + ast.Str(orig), + fmt_pass, + ) + ) + # If any hooks implement assert_pass hook + hook_impl_test = ast.If( + self.helper("_check_if_assertion_pass_impl"), + self.expl_stmts + [hook_call_pass], + [], + ) + statements_pass = [hook_impl_test] + + # Test for assertion condition + main_test = ast.If(negation, statements_fail, statements_pass) + self.statements.append(main_test) + if self.format_variables: + variables = [ + ast.Name(name, ast.Store()) for name in self.format_variables + ] + clear_format = ast.Assign(variables, _NameConstant(None)) + self.statements.append(clear_format) + + else: # Original assertion rewriting + # Create failure message. + body = self.expl_stmts + negation = ast.UnaryOp(ast.Not(), top_condition) + self.statements.append(ast.If(negation, body, [])) + if assert_.msg: + assertmsg = self.helper("_format_assertmsg", assert_.msg) + explanation = "\n>assert " + explanation + else: + assertmsg = ast.Str("") + explanation = "assert " + explanation + template = ast.BinOp(assertmsg, ast.Add(), ast.Str(explanation)) + msg = self.pop_format_context(template) + fmt = self.helper("_format_explanation", msg) + err_name = ast.Name("AssertionError", ast.Load()) + exc = ast.Call(err_name, [fmt], []) + raise_ = ast.Raise(exc, None) + + body.append(raise_) + # Clear temporary variables by setting them to None. if self.variables: variables = [ast.Name(name, ast.Store()) for name in self.variables] @@ -770,7 +850,7 @@ warn_explicit( app = ast.Attribute(expl_list, "append", ast.Load()) is_or = int(isinstance(boolop.op, ast.Or)) body = save = self.statements - fail_save = self.on_failure + fail_save = self.expl_stmts levels = len(boolop.values) - 1 self.push_format_context() # Process each operand, short-circuiting if needed. @@ -778,14 +858,14 @@ warn_explicit( if i: fail_inner = [] # cond is set in a prior loop iteration below - self.on_failure.append(ast.If(cond, fail_inner, [])) # noqa - self.on_failure = fail_inner + self.expl_stmts.append(ast.If(cond, fail_inner, [])) # noqa + self.expl_stmts = fail_inner self.push_format_context() res, expl = self.visit(v) body.append(ast.Assign([ast.Name(res_var, ast.Store())], res)) expl_format = self.pop_format_context(ast.Str(expl)) call = ast.Call(app, [expl_format], []) - self.on_failure.append(ast.Expr(call)) + self.expl_stmts.append(ast.Expr(call)) if i < levels: cond = res if is_or: @@ -794,7 +874,7 @@ warn_explicit( self.statements.append(ast.If(cond, inner, [])) self.statements = body = inner self.statements = save - self.on_failure = fail_save + self.expl_stmts = fail_save expl_template = self.helper("_format_boolop", expl_list, ast.Num(is_or)) expl = self.pop_format_context(expl_template) return ast.Name(res_var, ast.Load()), self.explanation_param(expl) diff --git a/src/_pytest/assertion/util.py b/src/_pytest/assertion/util.py index b808cb509..106c44a8a 100644 --- a/src/_pytest/assertion/util.py +++ b/src/_pytest/assertion/util.py @@ -12,6 +12,10 @@ from _pytest._io.saferepr import saferepr # DebugInterpreter. _reprcompare = None +# Works similarly as _reprcompare attribute. Is populated with the hook call +# when pytest_runtest_setup is called. +_assertion_pass = None + def format_explanation(explanation): """This formats an explanation diff --git a/src/_pytest/hookspec.py b/src/_pytest/hookspec.py index d40a36811..9e6d13fab 100644 --- a/src/_pytest/hookspec.py +++ b/src/_pytest/hookspec.py @@ -485,6 +485,42 @@ def pytest_assertrepr_compare(config, op, left, right): """ +def pytest_assertion_pass(item, lineno, orig, expl): + """ + **(Experimental)** + + Hook called whenever an assertion *passes*. + + Use this hook to do some processing after a passing assertion. + The original assertion information is available in the `orig` string + and the pytest introspected assertion information is available in the + `expl` string. + + This hook must be explicitly enabled by the ``enable_assertion_pass_hook`` + ini-file option: + + .. code-block:: ini + + [pytest] + enable_assertion_pass_hook=true + + You need to **clean the .pyc** files in your project directory and interpreter libraries + when enabling this option, as assertions will require to be re-written. + + :param _pytest.nodes.Item item: pytest item object of current test + :param int lineno: line number of the assert statement + :param string orig: string with original assertion + :param string expl: string with assert explanation + + .. note:: + + This hook is **experimental**, so its parameters or even the hook itself might + be changed/removed without warning in any future pytest release. + + If you find this hook useful, please share your feedback opening an issue. + """ + + # ------------------------------------------------------------------------- # hooks for influencing reporting (invoked from _pytest_terminal) # ------------------------------------------------------------------------- diff --git a/testing/acceptance_test.py b/testing/acceptance_test.py index 5567d994d..dbdf048a4 100644 --- a/testing/acceptance_test.py +++ b/testing/acceptance_test.py @@ -1101,7 +1101,10 @@ def test_fixture_values_leak(testdir): assert fix_of_test1_ref() is None """ ) - result = testdir.runpytest() + # Running on subprocess does not activate the HookRecorder + # which holds itself a reference to objects in case of the + # pytest_assert_reprcompare hook + result = testdir.runpytest_subprocess() result.stdout.fnmatch_lines(["* 2 passed *"]) diff --git a/testing/python/raises.py b/testing/python/raises.py index 89cef38f1..c9ede412a 100644 --- a/testing/python/raises.py +++ b/testing/python/raises.py @@ -202,6 +202,9 @@ class TestRaises: assert sys.exc_info() == (None, None, None) del t + # Make sure this does get updated in locals dict + # otherwise it could keep a reference + locals() # ensure the t instance is not stuck in a cyclic reference for o in gc.get_objects(): diff --git a/testing/test_assertrewrite.py b/testing/test_assertrewrite.py index 258b5d3b7..8d1c7a5f0 100644 --- a/testing/test_assertrewrite.py +++ b/testing/test_assertrewrite.py @@ -1332,3 +1332,115 @@ class TestEarlyRewriteBailout: ) result = testdir.runpytest() result.stdout.fnmatch_lines(["* 1 passed in *"]) + + +class TestAssertionPass: + def test_option_default(self, testdir): + config = testdir.parseconfig() + assert config.getini("enable_assertion_pass_hook") is False + + def test_hook_call(self, testdir): + testdir.makeconftest( + """ + def pytest_assertion_pass(item, lineno, orig, expl): + raise Exception("Assertion Passed: {} {} at line {}".format(orig, expl, lineno)) + """ + ) + + testdir.makeini( + """ + [pytest] + enable_assertion_pass_hook = True + """ + ) + + testdir.makepyfile( + """ + def test_simple(): + a=1 + b=2 + c=3 + d=0 + + assert a+b == c+d + + # cover failing assertions with a message + def test_fails(): + assert False, "assert with message" + """ + ) + result = testdir.runpytest() + result.stdout.fnmatch_lines( + "*Assertion Passed: a + b == c + d (1 + 2) == (3 + 0) at line 7*" + ) + + def test_hook_not_called_without_hookimpl(self, testdir, monkeypatch): + """Assertion pass should not be called (and hence formatting should + not occur) if there is no hook declared for pytest_assertion_pass""" + + def raise_on_assertionpass(*_, **__): + raise Exception("Assertion passed called when it shouldn't!") + + monkeypatch.setattr( + _pytest.assertion.rewrite, "_call_assertion_pass", raise_on_assertionpass + ) + + testdir.makeini( + """ + [pytest] + enable_assertion_pass_hook = True + """ + ) + + testdir.makepyfile( + """ + def test_simple(): + a=1 + b=2 + c=3 + d=0 + + assert a+b == c+d + """ + ) + result = testdir.runpytest() + result.assert_outcomes(passed=1) + + def test_hook_not_called_without_cmd_option(self, testdir, monkeypatch): + """Assertion pass should not be called (and hence formatting should + not occur) if there is no hook declared for pytest_assertion_pass""" + + def raise_on_assertionpass(*_, **__): + raise Exception("Assertion passed called when it shouldn't!") + + monkeypatch.setattr( + _pytest.assertion.rewrite, "_call_assertion_pass", raise_on_assertionpass + ) + + testdir.makeconftest( + """ + def pytest_assertion_pass(item, lineno, orig, expl): + raise Exception("Assertion Passed: {} {} at line {}".format(orig, expl, lineno)) + """ + ) + + testdir.makeini( + """ + [pytest] + enable_assertion_pass_hook = False + """ + ) + + testdir.makepyfile( + """ + def test_simple(): + a=1 + b=2 + c=3 + d=0 + + assert a+b == c+d + """ + ) + result = testdir.runpytest() + result.assert_outcomes(passed=1)