Assertion passed hook

This commit is contained in:
Victor Maryama
2019-06-24 16:09:39 +02:00
parent 3d01dd3adf
commit 9a89783fbb
10 changed files with 202 additions and 69 deletions

View File

@@ -92,7 +92,7 @@ def pytest_collection(session):
def pytest_runtest_setup(item):
"""Setup the pytest_assertrepr_compare hook
"""Setup the pytest_assertrepr_compare and pytest_assertion_pass hooks
The newinterpret and rewrite modules will use util._reprcompare if
it exists to use custom reporting via the
@@ -129,9 +129,15 @@ def pytest_runtest_setup(item):
util._reprcompare = callbinrepr
if item.ihook.pytest_assertion_pass.get_hookimpls():
def call_assertion_pass_hook(lineno, expl, orig):
item.ihook.pytest_assertion_pass(item=item, lineno=lineno, orig=orig, expl=expl)
util._assertion_pass = call_assertion_pass_hook
def pytest_runtest_teardown(item):
util._reprcompare = None
util._assertion_pass = None
def pytest_sessionfinish(session):

View File

@@ -1,5 +1,6 @@
"""Rewrite assertion AST to produce nice error messages"""
import ast
import astor
import errno
import imp
import itertools
@@ -357,6 +358,11 @@ def _rewrite_test(config, fn):
state.trace("failed to parse: {!r}".format(fn))
return None, None
rewrite_asserts(tree, fn, config)
# TODO: REMOVE, THIS IS ONLY FOR DEBUG
with open(f'{str(fn)+"bak"}', "w", encoding="utf-8") as f:
f.write(astor.to_source(tree))
try:
co = compile(tree, fn.strpath, "exec", dont_inherit=True)
except SyntaxError:
@@ -434,7 +440,7 @@ def _format_assertmsg(obj):
# contains a newline it gets escaped, however if an object has a
# .__repr__() which contains newlines it does not get escaped.
# However in either case we want to preserve the newline.
replaces = [("\n", "\n~"), ("%", "%%")]
replaces = [("\n", "\n~")]
if not isinstance(obj, str):
obj = saferepr(obj)
replaces.append(("\\n", "\n~"))
@@ -478,6 +484,17 @@ def _call_reprcompare(ops, results, expls, each_obj):
return expl
def _call_assertion_pass(lineno, orig, expl):
if util._assertion_pass is not None:
util._assertion_pass(lineno=lineno, orig=orig, expl=expl)
def _check_if_assertionpass_impl():
"""Checks if any plugins implement the pytest_assertion_pass hook
in order not to generate explanation unecessarily (might be expensive)"""
return True if util._assertion_pass else False
unary_map = {ast.Not: "not %s", ast.Invert: "~%s", ast.USub: "-%s", ast.UAdd: "+%s"}
binop_map = {
@@ -550,7 +567,8 @@ class AssertionRewriter(ast.NodeVisitor):
original assert statement: it rewrites the test of an assertion
to provide intermediate values and replace it with an if statement
which raises an assertion error with a detailed explanation in
case the expression is false.
case the expression is false and calls pytest_assertion_pass hook
if expression is true.
For this .visit_Assert() uses the visitor pattern to visit all the
AST nodes of the ast.Assert.test field, each visit call returning
@@ -568,9 +586,10 @@ class AssertionRewriter(ast.NodeVisitor):
by statements. Variables are created using .variable() and
have the form of "@py_assert0".
:on_failure: The AST statements which will be executed if the
assertion test fails. This is the code which will construct
the failure message and raises the AssertionError.
:expl_stmts: The AST statements which will be executed to get
data from the assertion. This is the code which will construct
the detailed assertion message that is used in the AssertionError
or for the pytest_assertion_pass hook.
:explanation_specifiers: A dict filled by .explanation_param()
with %-formatting placeholders and their corresponding
@@ -720,7 +739,7 @@ class AssertionRewriter(ast.NodeVisitor):
The expl_expr should be an ast.Str instance constructed from
the %-placeholders created by .explanation_param(). This will
add the required code to format said string to .on_failure and
add the required code to format said string to .expl_stmts and
return the ast.Name instance of the formatted string.
"""
@@ -731,7 +750,8 @@ class AssertionRewriter(ast.NodeVisitor):
format_dict = ast.Dict(keys, list(current.values()))
form = ast.BinOp(expl_expr, ast.Mod(), format_dict)
name = "@py_format" + str(next(self.variable_counter))
self.on_failure.append(ast.Assign([ast.Name(name, ast.Store())], form))
self.format_variables.append(name)
self.expl_stmts.append(ast.Assign([ast.Name(name, ast.Store())], form))
return ast.Name(name, ast.Load())
def generic_visit(self, node):
@@ -765,8 +785,9 @@ class AssertionRewriter(ast.NodeVisitor):
self.statements = []
self.variables = []
self.variable_counter = itertools.count()
self.format_variables = []
self.stack = []
self.on_failure = []
self.expl_stmts = []
self.push_format_context()
# Rewrite assert into a bunch of statements.
top_condition, explanation = self.visit(assert_.test)
@@ -777,24 +798,46 @@ class AssertionRewriter(ast.NodeVisitor):
top_condition, module_path=self.module_path, lineno=assert_.lineno
)
)
# Create failure message.
body = self.on_failure
negation = ast.UnaryOp(ast.Not(), top_condition)
self.statements.append(ast.If(negation, body, []))
msg = self.pop_format_context(ast.Str(explanation))
if assert_.msg:
assertmsg = self.helper("_format_assertmsg", assert_.msg)
explanation = "\n>assert " + explanation
gluestr = "\n>assert "
else:
assertmsg = ast.Str("")
explanation = "assert " + explanation
template = ast.BinOp(assertmsg, ast.Add(), ast.Str(explanation))
msg = self.pop_format_context(template)
fmt = self.helper("_format_explanation", msg)
gluestr = "assert "
err_explanation = ast.BinOp(ast.Str(gluestr), ast.Add(), msg)
err_msg = ast.BinOp(assertmsg, ast.Add(), err_explanation)
err_name = ast.Name("AssertionError", ast.Load())
fmt = self.helper("_format_explanation", err_msg)
fmt_pass = self.helper("_format_explanation", msg)
exc = ast.Call(err_name, [fmt], [])
raise_ = ast.Raise(exc, None)
if sys.version_info[0] >= 3:
raise_ = ast.Raise(exc, None)
else:
raise_ = ast.Raise(exc, None, None)
# Call to hook when passes
orig = astor.to_source(assert_.test).rstrip("\n").lstrip("(").rstrip(")")
hook_call_pass = ast.Expr(
self.helper(
"_call_assertion_pass", ast.Num(assert_.lineno), ast.Str(orig), fmt_pass
)
)
body.append(raise_)
# If any hooks implement assert_pass hook
hook_impl_test = ast.If(
self.helper("_check_if_assertionpass_impl"),
[hook_call_pass],
[],
)
main_test = ast.If(negation, [raise_], [hook_impl_test])
self.statements.extend(self.expl_stmts)
self.statements.append(main_test)
if self.format_variables:
variables = [ast.Name(name, ast.Store()) for name in self.format_variables]
clear_format = ast.Assign(variables, _NameConstant(None))
self.statements.append(clear_format)
# Clear temporary variables by setting them to None.
if self.variables:
variables = [ast.Name(name, ast.Store()) for name in self.variables]
@@ -848,7 +891,7 @@ warn_explicit(
app = ast.Attribute(expl_list, "append", ast.Load())
is_or = int(isinstance(boolop.op, ast.Or))
body = save = self.statements
fail_save = self.on_failure
fail_save = self.expl_stmts
levels = len(boolop.values) - 1
self.push_format_context()
# Process each operand, short-circuting if needed.
@@ -856,14 +899,14 @@ warn_explicit(
if i:
fail_inner = []
# cond is set in a prior loop iteration below
self.on_failure.append(ast.If(cond, fail_inner, [])) # noqa
self.on_failure = fail_inner
self.expl_stmts.append(ast.If(cond, fail_inner, [])) # noqa
self.expl_stmts = fail_inner
self.push_format_context()
res, expl = self.visit(v)
body.append(ast.Assign([ast.Name(res_var, ast.Store())], res))
expl_format = self.pop_format_context(ast.Str(expl))
call = ast.Call(app, [expl_format], [])
self.on_failure.append(ast.Expr(call))
self.expl_stmts.append(ast.Expr(call))
if i < levels:
cond = res
if is_or:
@@ -872,7 +915,7 @@ warn_explicit(
self.statements.append(ast.If(cond, inner, []))
self.statements = body = inner
self.statements = save
self.on_failure = fail_save
self.expl_stmts = fail_save
expl_template = self.helper("_format_boolop", expl_list, ast.Num(is_or))
expl = self.pop_format_context(expl_template)
return ast.Name(res_var, ast.Load()), self.explanation_param(expl)

View File

@@ -12,6 +12,10 @@ from _pytest._io.saferepr import saferepr
# DebugInterpreter.
_reprcompare = None
# Works similarly as _reprcompare attribute. Is populated with the hook call
# when pytest_runtest_setup is called.
_assertion_pass = None
def format_explanation(explanation):
"""This formats an explanation

View File

@@ -485,6 +485,21 @@ def pytest_assertrepr_compare(config, op, left, right):
"""
def pytest_assertion_pass(item, lineno, orig, expl):
"""Process explanation when assertions are valid.
Use this hook to do some processing after a passing assertion.
The original assertion information is available in the `orig` string
and the pytest introspected assertion information is available in the
`expl` string.
:param _pytest.nodes.Item item: pytest item object of current test
:param int lineno: line number of the assert statement
:param string orig: string with original assertion
:param string expl: string with assert explanation
"""
# -------------------------------------------------------------------------
# hooks for influencing reporting (invoked from _pytest_terminal)
# -------------------------------------------------------------------------