diff --git a/src/_pytest/assertion/__init__.py b/src/_pytest/assertion/__init__.py index f670afe4f..9e53c79f4 100644 --- a/src/_pytest/assertion/__init__.py +++ b/src/_pytest/assertion/__init__.py @@ -24,6 +24,14 @@ def pytest_addoption(parser): expression information.""", ) + group = parser.getgroup("experimental") + group.addoption( + "--enable-assertion-pass-hook", + action="store_true", + help="Enables the pytest_assertion_pass hook." + "Make sure to delete any previously generated pyc cache files.", + ) + def register_assert_rewrite(*names): """Register one or more module names to be rewritten on import. diff --git a/src/_pytest/assertion/rewrite.py b/src/_pytest/assertion/rewrite.py index ca3f18cf3..5477927b7 100644 --- a/src/_pytest/assertion/rewrite.py +++ b/src/_pytest/assertion/rewrite.py @@ -745,7 +745,8 @@ class AssertionRewriter(ast.NodeVisitor): format_dict = ast.Dict(keys, list(current.values())) form = ast.BinOp(expl_expr, ast.Mod(), format_dict) name = "@py_format" + str(next(self.variable_counter)) - self.format_variables.append(name) + if getattr(self.config._ns, "enable_assertion_pass_hook", False): + self.format_variables.append(name) self.expl_stmts.append(ast.Assign([ast.Name(name, ast.Store())], form)) return ast.Name(name, ast.Load()) @@ -780,7 +781,10 @@ class AssertionRewriter(ast.NodeVisitor): self.statements = [] self.variables = [] self.variable_counter = itertools.count() - self.format_variables = [] + + if getattr(self.config._ns, "enable_assertion_pass_hook", False): + self.format_variables = [] + self.stack = [] self.expl_stmts = [] self.push_format_context() @@ -793,41 +797,68 @@ class AssertionRewriter(ast.NodeVisitor): top_condition, module_path=self.module_path, lineno=assert_.lineno ) ) - negation = ast.UnaryOp(ast.Not(), top_condition) - msg = self.pop_format_context(ast.Str(explanation)) - if assert_.msg: - assertmsg = self.helper("_format_assertmsg", assert_.msg) - gluestr = "\n>assert " - else: - assertmsg = ast.Str("") - gluestr = "assert " - err_explanation = ast.BinOp(ast.Str(gluestr), ast.Add(), msg) - err_msg = ast.BinOp(assertmsg, ast.Add(), err_explanation) - err_name = ast.Name("AssertionError", ast.Load()) - fmt = self.helper("_format_explanation", err_msg) - fmt_pass = self.helper("_format_explanation", msg) - exc = ast.Call(err_name, [fmt], []) - raise_ = ast.Raise(exc, None) - # Call to hook when passes - orig = astor.to_source(assert_.test).rstrip("\n").lstrip("(").rstrip(")") - hook_call_pass = ast.Expr( - self.helper( - "_call_assertion_pass", ast.Num(assert_.lineno), ast.Str(orig), fmt_pass + if getattr(self.config._ns, "enable_assertion_pass_hook", False): + ### Experimental pytest_assertion_pass hook + negation = ast.UnaryOp(ast.Not(), top_condition) + msg = self.pop_format_context(ast.Str(explanation)) + if assert_.msg: + assertmsg = self.helper("_format_assertmsg", assert_.msg) + gluestr = "\n>assert " + else: + assertmsg = ast.Str("") + gluestr = "assert " + err_explanation = ast.BinOp(ast.Str(gluestr), ast.Add(), msg) + err_msg = ast.BinOp(assertmsg, ast.Add(), err_explanation) + err_name = ast.Name("AssertionError", ast.Load()) + fmt = self.helper("_format_explanation", err_msg) + fmt_pass = self.helper("_format_explanation", msg) + exc = ast.Call(err_name, [fmt], []) + raise_ = ast.Raise(exc, None) + # Call to hook when passes + orig = astor.to_source(assert_.test).rstrip("\n").lstrip("(").rstrip(")") + hook_call_pass = ast.Expr( + self.helper( + "_call_assertion_pass", + ast.Num(assert_.lineno), + ast.Str(orig), + fmt_pass, + ) ) - ) - # If any hooks implement assert_pass hook - hook_impl_test = ast.If( - self.helper("_check_if_assertionpass_impl"), [hook_call_pass], [] - ) - main_test = ast.If(negation, [raise_], [hook_impl_test]) + # If any hooks implement assert_pass hook + hook_impl_test = ast.If( + self.helper("_check_if_assertionpass_impl"), [hook_call_pass], [] + ) + main_test = ast.If(negation, [raise_], [hook_impl_test]) - self.statements.extend(self.expl_stmts) - self.statements.append(main_test) - if self.format_variables: - variables = [ast.Name(name, ast.Store()) for name in self.format_variables] - clear_format = ast.Assign(variables, _NameConstant(None)) - self.statements.append(clear_format) + self.statements.extend(self.expl_stmts) + self.statements.append(main_test) + if self.format_variables: + variables = [ + ast.Name(name, ast.Store()) for name in self.format_variables + ] + clear_format = ast.Assign(variables, _NameConstant(None)) + self.statements.append(clear_format) + else: + ### Original assertion rewriting + # Create failure message. + body = self.expl_stmts + negation = ast.UnaryOp(ast.Not(), top_condition) + self.statements.append(ast.If(negation, body, [])) + if assert_.msg: + assertmsg = self.helper("_format_assertmsg", assert_.msg) + explanation = "\n>assert " + explanation + else: + assertmsg = ast.Str("") + explanation = "assert " + explanation + template = ast.BinOp(assertmsg, ast.Add(), ast.Str(explanation)) + msg = self.pop_format_context(template) + fmt = self.helper("_format_explanation", msg) + err_name = ast.Name("AssertionError", ast.Load()) + exc = ast.Call(err_name, [fmt], []) + raise_ = ast.Raise(exc, None) + + body.append(raise_) # Clear temporary variables by setting them to None. if self.variables: variables = [ast.Name(name, ast.Store()) for name in self.variables] diff --git a/src/_pytest/config/__init__.py b/src/_pytest/config/__init__.py index e6de86c36..f947a4c32 100644 --- a/src/_pytest/config/__init__.py +++ b/src/_pytest/config/__init__.py @@ -746,8 +746,10 @@ class Config: and find all the installed plugins to mark them for rewriting by the importhook. """ - ns, unknown_args = self._parser.parse_known_and_unknown_args(args) - mode = getattr(ns, "assertmode", "plain") + # Saving _ns so it can be used for other assertion rewriting purposes + # e.g. experimental assertion pass hook + self._ns, self._unknown_args = self._parser.parse_known_and_unknown_args(args) + mode = getattr(self._ns, "assertmode", "plain") if mode == "rewrite": try: hook = _pytest.assertion.install_importhook(self) diff --git a/src/_pytest/hookspec.py b/src/_pytest/hookspec.py index c22b4c12a..5cb1d9ce5 100644 --- a/src/_pytest/hookspec.py +++ b/src/_pytest/hookspec.py @@ -503,6 +503,8 @@ def pytest_assertion_pass(item, lineno, orig, expl): This hook is still *experimental*, so its parameters or even the hook itself might be changed/removed without warning in any future pytest release. + It should be enabled using the `--enable-assertion-pass-hook` command line option. + If you find this hook useful, please share your feedback opening an issue. """ diff --git a/testing/test_assertrewrite.py b/testing/test_assertrewrite.py index d3c50511f..e74e6df83 100644 --- a/testing/test_assertrewrite.py +++ b/testing/test_assertrewrite.py @@ -1326,8 +1326,7 @@ class TestAssertionPass: assert a+b == c+d """ ) - result = testdir.runpytest() - print(testdir.tmpdir) + result = testdir.runpytest("--enable-assertion-pass-hook") result.stdout.fnmatch_lines( "*Assertion Passed: a + b == c + d (1 + 2) == (3 + 0) at line 7*" ) @@ -1343,6 +1342,38 @@ class TestAssertionPass: _pytest.assertion.rewrite, "_call_assertion_pass", raise_on_assertionpass ) + testdir.makepyfile( + """ + def test_simple(): + a=1 + b=2 + c=3 + d=0 + + assert a+b == c+d + """ + ) + result = testdir.runpytest("--enable-assertion-pass-hook") + result.assert_outcomes(passed=1) + + def test_hook_not_called_without_cmd_option(self, testdir, monkeypatch): + """Assertion pass should not be called (and hence formatting should + not occur) if there is no hook declared for pytest_assertion_pass""" + + def raise_on_assertionpass(*_, **__): + raise Exception("Assertion passed called when it shouldn't!") + + monkeypatch.setattr( + _pytest.assertion.rewrite, "_call_assertion_pass", raise_on_assertionpass + ) + + testdir.makeconftest( + """ + def pytest_assertion_pass(item, lineno, orig, expl): + raise Exception("Assertion Passed: {} {} at line {}".format(orig, expl, lineno)) + """ + ) + testdir.makepyfile( """ def test_simple():