Assertion passed hook

This commit is contained in:
Victor Maryama 2019-06-24 16:09:39 +02:00
parent 3d01dd3adf
commit 9a89783fbb
10 changed files with 202 additions and 69 deletions

View File

@ -0,0 +1,2 @@
Adds ``pytest_assertion_pass`` hook, called with assertion context information
(original asssertion statement and pytest explanation) whenever an assertion passes.

View File

@ -13,6 +13,7 @@ INSTALL_REQUIRES = [
"pluggy>=0.12,<1.0", "pluggy>=0.12,<1.0",
"importlib-metadata>=0.12", "importlib-metadata>=0.12",
"wcwidth", "wcwidth",
"astor",
] ]

View File

@ -92,7 +92,7 @@ def pytest_collection(session):
def pytest_runtest_setup(item): def pytest_runtest_setup(item):
"""Setup the pytest_assertrepr_compare hook """Setup the pytest_assertrepr_compare and pytest_assertion_pass hooks
The newinterpret and rewrite modules will use util._reprcompare if The newinterpret and rewrite modules will use util._reprcompare if
it exists to use custom reporting via the it exists to use custom reporting via the
@ -129,9 +129,15 @@ def pytest_runtest_setup(item):
util._reprcompare = callbinrepr util._reprcompare = callbinrepr
if item.ihook.pytest_assertion_pass.get_hookimpls():
def call_assertion_pass_hook(lineno, expl, orig):
item.ihook.pytest_assertion_pass(item=item, lineno=lineno, orig=orig, expl=expl)
util._assertion_pass = call_assertion_pass_hook
def pytest_runtest_teardown(item): def pytest_runtest_teardown(item):
util._reprcompare = None util._reprcompare = None
util._assertion_pass = None
def pytest_sessionfinish(session): def pytest_sessionfinish(session):

View File

@ -1,5 +1,6 @@
"""Rewrite assertion AST to produce nice error messages""" """Rewrite assertion AST to produce nice error messages"""
import ast import ast
import astor
import errno import errno
import imp import imp
import itertools import itertools
@ -357,6 +358,11 @@ def _rewrite_test(config, fn):
state.trace("failed to parse: {!r}".format(fn)) state.trace("failed to parse: {!r}".format(fn))
return None, None return None, None
rewrite_asserts(tree, fn, config) rewrite_asserts(tree, fn, config)
# TODO: REMOVE, THIS IS ONLY FOR DEBUG
with open(f'{str(fn)+"bak"}', "w", encoding="utf-8") as f:
f.write(astor.to_source(tree))
try: try:
co = compile(tree, fn.strpath, "exec", dont_inherit=True) co = compile(tree, fn.strpath, "exec", dont_inherit=True)
except SyntaxError: except SyntaxError:
@ -434,7 +440,7 @@ def _format_assertmsg(obj):
# contains a newline it gets escaped, however if an object has a # contains a newline it gets escaped, however if an object has a
# .__repr__() which contains newlines it does not get escaped. # .__repr__() which contains newlines it does not get escaped.
# However in either case we want to preserve the newline. # However in either case we want to preserve the newline.
replaces = [("\n", "\n~"), ("%", "%%")] replaces = [("\n", "\n~")]
if not isinstance(obj, str): if not isinstance(obj, str):
obj = saferepr(obj) obj = saferepr(obj)
replaces.append(("\\n", "\n~")) replaces.append(("\\n", "\n~"))
@ -478,6 +484,17 @@ def _call_reprcompare(ops, results, expls, each_obj):
return expl return expl
def _call_assertion_pass(lineno, orig, expl):
if util._assertion_pass is not None:
util._assertion_pass(lineno=lineno, orig=orig, expl=expl)
def _check_if_assertionpass_impl():
"""Checks if any plugins implement the pytest_assertion_pass hook
in order not to generate explanation unecessarily (might be expensive)"""
return True if util._assertion_pass else False
unary_map = {ast.Not: "not %s", ast.Invert: "~%s", ast.USub: "-%s", ast.UAdd: "+%s"} unary_map = {ast.Not: "not %s", ast.Invert: "~%s", ast.USub: "-%s", ast.UAdd: "+%s"}
binop_map = { binop_map = {
@ -550,7 +567,8 @@ class AssertionRewriter(ast.NodeVisitor):
original assert statement: it rewrites the test of an assertion original assert statement: it rewrites the test of an assertion
to provide intermediate values and replace it with an if statement to provide intermediate values and replace it with an if statement
which raises an assertion error with a detailed explanation in which raises an assertion error with a detailed explanation in
case the expression is false. case the expression is false and calls pytest_assertion_pass hook
if expression is true.
For this .visit_Assert() uses the visitor pattern to visit all the For this .visit_Assert() uses the visitor pattern to visit all the
AST nodes of the ast.Assert.test field, each visit call returning AST nodes of the ast.Assert.test field, each visit call returning
@ -568,9 +586,10 @@ class AssertionRewriter(ast.NodeVisitor):
by statements. Variables are created using .variable() and by statements. Variables are created using .variable() and
have the form of "@py_assert0". have the form of "@py_assert0".
:on_failure: The AST statements which will be executed if the :expl_stmts: The AST statements which will be executed to get
assertion test fails. This is the code which will construct data from the assertion. This is the code which will construct
the failure message and raises the AssertionError. the detailed assertion message that is used in the AssertionError
or for the pytest_assertion_pass hook.
:explanation_specifiers: A dict filled by .explanation_param() :explanation_specifiers: A dict filled by .explanation_param()
with %-formatting placeholders and their corresponding with %-formatting placeholders and their corresponding
@ -720,7 +739,7 @@ class AssertionRewriter(ast.NodeVisitor):
The expl_expr should be an ast.Str instance constructed from The expl_expr should be an ast.Str instance constructed from
the %-placeholders created by .explanation_param(). This will the %-placeholders created by .explanation_param(). This will
add the required code to format said string to .on_failure and add the required code to format said string to .expl_stmts and
return the ast.Name instance of the formatted string. return the ast.Name instance of the formatted string.
""" """
@ -731,7 +750,8 @@ class AssertionRewriter(ast.NodeVisitor):
format_dict = ast.Dict(keys, list(current.values())) format_dict = ast.Dict(keys, list(current.values()))
form = ast.BinOp(expl_expr, ast.Mod(), format_dict) form = ast.BinOp(expl_expr, ast.Mod(), format_dict)
name = "@py_format" + str(next(self.variable_counter)) name = "@py_format" + str(next(self.variable_counter))
self.on_failure.append(ast.Assign([ast.Name(name, ast.Store())], form)) self.format_variables.append(name)
self.expl_stmts.append(ast.Assign([ast.Name(name, ast.Store())], form))
return ast.Name(name, ast.Load()) return ast.Name(name, ast.Load())
def generic_visit(self, node): def generic_visit(self, node):
@ -765,8 +785,9 @@ class AssertionRewriter(ast.NodeVisitor):
self.statements = [] self.statements = []
self.variables = [] self.variables = []
self.variable_counter = itertools.count() self.variable_counter = itertools.count()
self.format_variables = []
self.stack = [] self.stack = []
self.on_failure = [] self.expl_stmts = []
self.push_format_context() self.push_format_context()
# Rewrite assert into a bunch of statements. # Rewrite assert into a bunch of statements.
top_condition, explanation = self.visit(assert_.test) top_condition, explanation = self.visit(assert_.test)
@ -777,24 +798,46 @@ class AssertionRewriter(ast.NodeVisitor):
top_condition, module_path=self.module_path, lineno=assert_.lineno top_condition, module_path=self.module_path, lineno=assert_.lineno
) )
) )
# Create failure message.
body = self.on_failure
negation = ast.UnaryOp(ast.Not(), top_condition) negation = ast.UnaryOp(ast.Not(), top_condition)
self.statements.append(ast.If(negation, body, [])) msg = self.pop_format_context(ast.Str(explanation))
if assert_.msg: if assert_.msg:
assertmsg = self.helper("_format_assertmsg", assert_.msg) assertmsg = self.helper("_format_assertmsg", assert_.msg)
explanation = "\n>assert " + explanation gluestr = "\n>assert "
else: else:
assertmsg = ast.Str("") assertmsg = ast.Str("")
explanation = "assert " + explanation gluestr = "assert "
template = ast.BinOp(assertmsg, ast.Add(), ast.Str(explanation)) err_explanation = ast.BinOp(ast.Str(gluestr), ast.Add(), msg)
msg = self.pop_format_context(template) err_msg = ast.BinOp(assertmsg, ast.Add(), err_explanation)
fmt = self.helper("_format_explanation", msg)
err_name = ast.Name("AssertionError", ast.Load()) err_name = ast.Name("AssertionError", ast.Load())
fmt = self.helper("_format_explanation", err_msg)
fmt_pass = self.helper("_format_explanation", msg)
exc = ast.Call(err_name, [fmt], []) exc = ast.Call(err_name, [fmt], [])
raise_ = ast.Raise(exc, None) if sys.version_info[0] >= 3:
raise_ = ast.Raise(exc, None)
else:
raise_ = ast.Raise(exc, None, None)
# Call to hook when passes
orig = astor.to_source(assert_.test).rstrip("\n").lstrip("(").rstrip(")")
hook_call_pass = ast.Expr(
self.helper(
"_call_assertion_pass", ast.Num(assert_.lineno), ast.Str(orig), fmt_pass
)
)
body.append(raise_) # If any hooks implement assert_pass hook
hook_impl_test = ast.If(
self.helper("_check_if_assertionpass_impl"),
[hook_call_pass],
[],
)
main_test = ast.If(negation, [raise_], [hook_impl_test])
self.statements.extend(self.expl_stmts)
self.statements.append(main_test)
if self.format_variables:
variables = [ast.Name(name, ast.Store()) for name in self.format_variables]
clear_format = ast.Assign(variables, _NameConstant(None))
self.statements.append(clear_format)
# Clear temporary variables by setting them to None. # Clear temporary variables by setting them to None.
if self.variables: if self.variables:
variables = [ast.Name(name, ast.Store()) for name in self.variables] variables = [ast.Name(name, ast.Store()) for name in self.variables]
@ -848,7 +891,7 @@ warn_explicit(
app = ast.Attribute(expl_list, "append", ast.Load()) app = ast.Attribute(expl_list, "append", ast.Load())
is_or = int(isinstance(boolop.op, ast.Or)) is_or = int(isinstance(boolop.op, ast.Or))
body = save = self.statements body = save = self.statements
fail_save = self.on_failure fail_save = self.expl_stmts
levels = len(boolop.values) - 1 levels = len(boolop.values) - 1
self.push_format_context() self.push_format_context()
# Process each operand, short-circuting if needed. # Process each operand, short-circuting if needed.
@ -856,14 +899,14 @@ warn_explicit(
if i: if i:
fail_inner = [] fail_inner = []
# cond is set in a prior loop iteration below # cond is set in a prior loop iteration below
self.on_failure.append(ast.If(cond, fail_inner, [])) # noqa self.expl_stmts.append(ast.If(cond, fail_inner, [])) # noqa
self.on_failure = fail_inner self.expl_stmts = fail_inner
self.push_format_context() self.push_format_context()
res, expl = self.visit(v) res, expl = self.visit(v)
body.append(ast.Assign([ast.Name(res_var, ast.Store())], res)) body.append(ast.Assign([ast.Name(res_var, ast.Store())], res))
expl_format = self.pop_format_context(ast.Str(expl)) expl_format = self.pop_format_context(ast.Str(expl))
call = ast.Call(app, [expl_format], []) call = ast.Call(app, [expl_format], [])
self.on_failure.append(ast.Expr(call)) self.expl_stmts.append(ast.Expr(call))
if i < levels: if i < levels:
cond = res cond = res
if is_or: if is_or:
@ -872,7 +915,7 @@ warn_explicit(
self.statements.append(ast.If(cond, inner, [])) self.statements.append(ast.If(cond, inner, []))
self.statements = body = inner self.statements = body = inner
self.statements = save self.statements = save
self.on_failure = fail_save self.expl_stmts = fail_save
expl_template = self.helper("_format_boolop", expl_list, ast.Num(is_or)) expl_template = self.helper("_format_boolop", expl_list, ast.Num(is_or))
expl = self.pop_format_context(expl_template) expl = self.pop_format_context(expl_template)
return ast.Name(res_var, ast.Load()), self.explanation_param(expl) return ast.Name(res_var, ast.Load()), self.explanation_param(expl)

View File

@ -12,6 +12,10 @@ from _pytest._io.saferepr import saferepr
# DebugInterpreter. # DebugInterpreter.
_reprcompare = None _reprcompare = None
# Works similarly as _reprcompare attribute. Is populated with the hook call
# when pytest_runtest_setup is called.
_assertion_pass = None
def format_explanation(explanation): def format_explanation(explanation):
"""This formats an explanation """This formats an explanation

View File

@ -485,6 +485,21 @@ def pytest_assertrepr_compare(config, op, left, right):
""" """
def pytest_assertion_pass(item, lineno, orig, expl):
"""Process explanation when assertions are valid.
Use this hook to do some processing after a passing assertion.
The original assertion information is available in the `orig` string
and the pytest introspected assertion information is available in the
`expl` string.
:param _pytest.nodes.Item item: pytest item object of current test
:param int lineno: line number of the assert statement
:param string orig: string with original assertion
:param string expl: string with assert explanation
"""
# ------------------------------------------------------------------------- # -------------------------------------------------------------------------
# hooks for influencing reporting (invoked from _pytest_terminal) # hooks for influencing reporting (invoked from _pytest_terminal)
# ------------------------------------------------------------------------- # -------------------------------------------------------------------------

View File

@ -1047,51 +1047,6 @@ def test_deferred_hook_checking(testdir):
result.stdout.fnmatch_lines(["* 1 passed *"]) result.stdout.fnmatch_lines(["* 1 passed *"])
def test_fixture_values_leak(testdir):
"""Ensure that fixture objects are properly destroyed by the garbage collector at the end of their expected
life-times (#2981).
"""
testdir.makepyfile(
"""
import attr
import gc
import pytest
import weakref
@attr.s
class SomeObj(object):
name = attr.ib()
fix_of_test1_ref = None
session_ref = None
@pytest.fixture(scope='session')
def session_fix():
global session_ref
obj = SomeObj(name='session-fixture')
session_ref = weakref.ref(obj)
return obj
@pytest.fixture
def fix(session_fix):
global fix_of_test1_ref
obj = SomeObj(name='local-fixture')
fix_of_test1_ref = weakref.ref(obj)
return obj
def test1(fix):
assert fix_of_test1_ref() is fix
def test2():
gc.collect()
# fixture "fix" created during test1 must have been destroyed by now
assert fix_of_test1_ref() is None
"""
)
result = testdir.runpytest()
result.stdout.fnmatch_lines(["* 2 passed *"])
def test_fixture_order_respects_scope(testdir): def test_fixture_order_respects_scope(testdir):
"""Ensure that fixtures are created according to scope order, regression test for #2405 """Ensure that fixtures are created according to scope order, regression test for #2405
""" """

View File

@ -0,0 +1,53 @@
"""Ensure that fixture objects are properly destroyed by the garbage collector at the end of their expected
life-times (#2981).
This comes from the old acceptance_test.py::test_fixture_values_leak(testdir):
This used pytester before but was not working when using pytest_assert_reprcompare
because pytester tracks hook calls and it would hold a reference (ParsedCall object),
preventing garbage collection
<ParsedCall 'pytest_assertrepr_compare'(**{
'config': <_pytest.config.Config object at 0x0000019C18D1C2B0>,
'op': 'is',
'left': SomeObj(name='local-fixture'),
'right': SomeObj(name='local-fixture')})>
"""
import attr
import gc
import pytest
import weakref
@attr.s
class SomeObj(object):
name = attr.ib()
fix_of_test1_ref = None
session_ref = None
@pytest.fixture(scope="session")
def session_fix():
global session_ref
obj = SomeObj(name="session-fixture")
session_ref = weakref.ref(obj)
return obj
@pytest.fixture
def fix(session_fix):
global fix_of_test1_ref
obj = SomeObj(name="local-fixture")
fix_of_test1_ref = weakref.ref(obj)
return obj
def test1(fix):
assert fix_of_test1_ref() is fix
def test2():
gc.collect()
# fixture "fix" created during test1 must have been destroyed by now
assert fix_of_test1_ref() is None

View File

@ -202,6 +202,9 @@ class TestRaises:
assert sys.exc_info() == (None, None, None) assert sys.exc_info() == (None, None, None)
del t del t
# Make sure this does get updated in locals dict
# otherwise it could keep a reference
locals()
# ensure the t instance is not stuck in a cyclic reference # ensure the t instance is not stuck in a cyclic reference
for o in gc.get_objects(): for o in gc.get_objects():

View File

@ -1305,3 +1305,54 @@ class TestEarlyRewriteBailout:
) )
result = testdir.runpytest() result = testdir.runpytest()
result.stdout.fnmatch_lines(["* 1 passed in *"]) result.stdout.fnmatch_lines(["* 1 passed in *"])
class TestAssertionPass:
def test_hook_call(self, testdir):
testdir.makeconftest(
"""
def pytest_assertion_pass(item, lineno, orig, expl):
raise Exception("Assertion Passed: {} {} at line {}".format(orig, expl, lineno))
"""
)
testdir.makepyfile(
"""
def test_simple():
a=1
b=2
c=3
d=0
assert a+b == c+d
"""
)
result = testdir.runpytest()
print(testdir.tmpdir)
result.stdout.fnmatch_lines(
"*Assertion Passed: a + b == c + d (1 + 2) == (3 + 0) at line 7*"
)
def test_hook_not_called_without_hookimpl(self, testdir, monkeypatch):
"""Assertion pass should not be called (and hence formatting should
not occur) if there is no hook declared for pytest_assertion_pass"""
def raise_on_assertionpass(*_, **__):
raise Exception("Assertion passed called when it shouldn't!")
monkeypatch.setattr(_pytest.assertion.rewrite,
"_call_assertion_pass",
raise_on_assertionpass)
testdir.makepyfile(
"""
def test_simple():
a=1
b=2
c=3
d=0
assert a+b == c+d
"""
)
result = testdir.runpytest()
result.assert_outcomes(passed=1)