|
|
|
|
@@ -1,5 +1,6 @@
|
|
|
|
|
"""Rewrite assertion AST to produce nice error messages"""
|
|
|
|
|
import ast
|
|
|
|
|
import astor
|
|
|
|
|
import errno
|
|
|
|
|
import imp
|
|
|
|
|
import itertools
|
|
|
|
|
@@ -357,6 +358,11 @@ def _rewrite_test(config, fn):
|
|
|
|
|
state.trace("failed to parse: {!r}".format(fn))
|
|
|
|
|
return None, None
|
|
|
|
|
rewrite_asserts(tree, fn, config)
|
|
|
|
|
|
|
|
|
|
# TODO: REMOVE, THIS IS ONLY FOR DEBUG
|
|
|
|
|
with open(f'{str(fn)+"bak"}', "w", encoding="utf-8") as f:
|
|
|
|
|
f.write(astor.to_source(tree))
|
|
|
|
|
|
|
|
|
|
try:
|
|
|
|
|
co = compile(tree, fn.strpath, "exec", dont_inherit=True)
|
|
|
|
|
except SyntaxError:
|
|
|
|
|
@@ -434,7 +440,7 @@ def _format_assertmsg(obj):
|
|
|
|
|
# contains a newline it gets escaped, however if an object has a
|
|
|
|
|
# .__repr__() which contains newlines it does not get escaped.
|
|
|
|
|
# However in either case we want to preserve the newline.
|
|
|
|
|
replaces = [("\n", "\n~"), ("%", "%%")]
|
|
|
|
|
replaces = [("\n", "\n~")]
|
|
|
|
|
if not isinstance(obj, str):
|
|
|
|
|
obj = saferepr(obj)
|
|
|
|
|
replaces.append(("\\n", "\n~"))
|
|
|
|
|
@@ -478,6 +484,17 @@ def _call_reprcompare(ops, results, expls, each_obj):
|
|
|
|
|
return expl
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _call_assertion_pass(lineno, orig, expl):
|
|
|
|
|
if util._assertion_pass is not None:
|
|
|
|
|
util._assertion_pass(lineno=lineno, orig=orig, expl=expl)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _check_if_assertionpass_impl():
|
|
|
|
|
"""Checks if any plugins implement the pytest_assertion_pass hook
|
|
|
|
|
in order not to generate explanation unecessarily (might be expensive)"""
|
|
|
|
|
return True if util._assertion_pass else False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
unary_map = {ast.Not: "not %s", ast.Invert: "~%s", ast.USub: "-%s", ast.UAdd: "+%s"}
|
|
|
|
|
|
|
|
|
|
binop_map = {
|
|
|
|
|
@@ -550,7 +567,8 @@ class AssertionRewriter(ast.NodeVisitor):
|
|
|
|
|
original assert statement: it rewrites the test of an assertion
|
|
|
|
|
to provide intermediate values and replace it with an if statement
|
|
|
|
|
which raises an assertion error with a detailed explanation in
|
|
|
|
|
case the expression is false.
|
|
|
|
|
case the expression is false and calls pytest_assertion_pass hook
|
|
|
|
|
if expression is true.
|
|
|
|
|
|
|
|
|
|
For this .visit_Assert() uses the visitor pattern to visit all the
|
|
|
|
|
AST nodes of the ast.Assert.test field, each visit call returning
|
|
|
|
|
@@ -568,9 +586,10 @@ class AssertionRewriter(ast.NodeVisitor):
|
|
|
|
|
by statements. Variables are created using .variable() and
|
|
|
|
|
have the form of "@py_assert0".
|
|
|
|
|
|
|
|
|
|
:on_failure: The AST statements which will be executed if the
|
|
|
|
|
assertion test fails. This is the code which will construct
|
|
|
|
|
the failure message and raises the AssertionError.
|
|
|
|
|
:expl_stmts: The AST statements which will be executed to get
|
|
|
|
|
data from the assertion. This is the code which will construct
|
|
|
|
|
the detailed assertion message that is used in the AssertionError
|
|
|
|
|
or for the pytest_assertion_pass hook.
|
|
|
|
|
|
|
|
|
|
:explanation_specifiers: A dict filled by .explanation_param()
|
|
|
|
|
with %-formatting placeholders and their corresponding
|
|
|
|
|
@@ -720,7 +739,7 @@ class AssertionRewriter(ast.NodeVisitor):
|
|
|
|
|
|
|
|
|
|
The expl_expr should be an ast.Str instance constructed from
|
|
|
|
|
the %-placeholders created by .explanation_param(). This will
|
|
|
|
|
add the required code to format said string to .on_failure and
|
|
|
|
|
add the required code to format said string to .expl_stmts and
|
|
|
|
|
return the ast.Name instance of the formatted string.
|
|
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
@@ -731,7 +750,8 @@ class AssertionRewriter(ast.NodeVisitor):
|
|
|
|
|
format_dict = ast.Dict(keys, list(current.values()))
|
|
|
|
|
form = ast.BinOp(expl_expr, ast.Mod(), format_dict)
|
|
|
|
|
name = "@py_format" + str(next(self.variable_counter))
|
|
|
|
|
self.on_failure.append(ast.Assign([ast.Name(name, ast.Store())], form))
|
|
|
|
|
self.format_variables.append(name)
|
|
|
|
|
self.expl_stmts.append(ast.Assign([ast.Name(name, ast.Store())], form))
|
|
|
|
|
return ast.Name(name, ast.Load())
|
|
|
|
|
|
|
|
|
|
def generic_visit(self, node):
|
|
|
|
|
@@ -765,8 +785,9 @@ class AssertionRewriter(ast.NodeVisitor):
|
|
|
|
|
self.statements = []
|
|
|
|
|
self.variables = []
|
|
|
|
|
self.variable_counter = itertools.count()
|
|
|
|
|
self.format_variables = []
|
|
|
|
|
self.stack = []
|
|
|
|
|
self.on_failure = []
|
|
|
|
|
self.expl_stmts = []
|
|
|
|
|
self.push_format_context()
|
|
|
|
|
# Rewrite assert into a bunch of statements.
|
|
|
|
|
top_condition, explanation = self.visit(assert_.test)
|
|
|
|
|
@@ -777,24 +798,46 @@ class AssertionRewriter(ast.NodeVisitor):
|
|
|
|
|
top_condition, module_path=self.module_path, lineno=assert_.lineno
|
|
|
|
|
)
|
|
|
|
|
)
|
|
|
|
|
# Create failure message.
|
|
|
|
|
body = self.on_failure
|
|
|
|
|
negation = ast.UnaryOp(ast.Not(), top_condition)
|
|
|
|
|
self.statements.append(ast.If(negation, body, []))
|
|
|
|
|
msg = self.pop_format_context(ast.Str(explanation))
|
|
|
|
|
if assert_.msg:
|
|
|
|
|
assertmsg = self.helper("_format_assertmsg", assert_.msg)
|
|
|
|
|
explanation = "\n>assert " + explanation
|
|
|
|
|
gluestr = "\n>assert "
|
|
|
|
|
else:
|
|
|
|
|
assertmsg = ast.Str("")
|
|
|
|
|
explanation = "assert " + explanation
|
|
|
|
|
template = ast.BinOp(assertmsg, ast.Add(), ast.Str(explanation))
|
|
|
|
|
msg = self.pop_format_context(template)
|
|
|
|
|
fmt = self.helper("_format_explanation", msg)
|
|
|
|
|
gluestr = "assert "
|
|
|
|
|
err_explanation = ast.BinOp(ast.Str(gluestr), ast.Add(), msg)
|
|
|
|
|
err_msg = ast.BinOp(assertmsg, ast.Add(), err_explanation)
|
|
|
|
|
err_name = ast.Name("AssertionError", ast.Load())
|
|
|
|
|
fmt = self.helper("_format_explanation", err_msg)
|
|
|
|
|
fmt_pass = self.helper("_format_explanation", msg)
|
|
|
|
|
exc = ast.Call(err_name, [fmt], [])
|
|
|
|
|
raise_ = ast.Raise(exc, None)
|
|
|
|
|
if sys.version_info[0] >= 3:
|
|
|
|
|
raise_ = ast.Raise(exc, None)
|
|
|
|
|
else:
|
|
|
|
|
raise_ = ast.Raise(exc, None, None)
|
|
|
|
|
# Call to hook when passes
|
|
|
|
|
orig = astor.to_source(assert_.test).rstrip("\n").lstrip("(").rstrip(")")
|
|
|
|
|
hook_call_pass = ast.Expr(
|
|
|
|
|
self.helper(
|
|
|
|
|
"_call_assertion_pass", ast.Num(assert_.lineno), ast.Str(orig), fmt_pass
|
|
|
|
|
)
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
body.append(raise_)
|
|
|
|
|
# If any hooks implement assert_pass hook
|
|
|
|
|
hook_impl_test = ast.If(
|
|
|
|
|
self.helper("_check_if_assertionpass_impl"),
|
|
|
|
|
[hook_call_pass],
|
|
|
|
|
[],
|
|
|
|
|
)
|
|
|
|
|
main_test = ast.If(negation, [raise_], [hook_impl_test])
|
|
|
|
|
|
|
|
|
|
self.statements.extend(self.expl_stmts)
|
|
|
|
|
self.statements.append(main_test)
|
|
|
|
|
if self.format_variables:
|
|
|
|
|
variables = [ast.Name(name, ast.Store()) for name in self.format_variables]
|
|
|
|
|
clear_format = ast.Assign(variables, _NameConstant(None))
|
|
|
|
|
self.statements.append(clear_format)
|
|
|
|
|
# Clear temporary variables by setting them to None.
|
|
|
|
|
if self.variables:
|
|
|
|
|
variables = [ast.Name(name, ast.Store()) for name in self.variables]
|
|
|
|
|
@@ -848,7 +891,7 @@ warn_explicit(
|
|
|
|
|
app = ast.Attribute(expl_list, "append", ast.Load())
|
|
|
|
|
is_or = int(isinstance(boolop.op, ast.Or))
|
|
|
|
|
body = save = self.statements
|
|
|
|
|
fail_save = self.on_failure
|
|
|
|
|
fail_save = self.expl_stmts
|
|
|
|
|
levels = len(boolop.values) - 1
|
|
|
|
|
self.push_format_context()
|
|
|
|
|
# Process each operand, short-circuting if needed.
|
|
|
|
|
@@ -856,14 +899,14 @@ warn_explicit(
|
|
|
|
|
if i:
|
|
|
|
|
fail_inner = []
|
|
|
|
|
# cond is set in a prior loop iteration below
|
|
|
|
|
self.on_failure.append(ast.If(cond, fail_inner, [])) # noqa
|
|
|
|
|
self.on_failure = fail_inner
|
|
|
|
|
self.expl_stmts.append(ast.If(cond, fail_inner, [])) # noqa
|
|
|
|
|
self.expl_stmts = fail_inner
|
|
|
|
|
self.push_format_context()
|
|
|
|
|
res, expl = self.visit(v)
|
|
|
|
|
body.append(ast.Assign([ast.Name(res_var, ast.Store())], res))
|
|
|
|
|
expl_format = self.pop_format_context(ast.Str(expl))
|
|
|
|
|
call = ast.Call(app, [expl_format], [])
|
|
|
|
|
self.on_failure.append(ast.Expr(call))
|
|
|
|
|
self.expl_stmts.append(ast.Expr(call))
|
|
|
|
|
if i < levels:
|
|
|
|
|
cond = res
|
|
|
|
|
if is_or:
|
|
|
|
|
@@ -872,7 +915,7 @@ warn_explicit(
|
|
|
|
|
self.statements.append(ast.If(cond, inner, []))
|
|
|
|
|
self.statements = body = inner
|
|
|
|
|
self.statements = save
|
|
|
|
|
self.on_failure = fail_save
|
|
|
|
|
self.expl_stmts = fail_save
|
|
|
|
|
expl_template = self.helper("_format_boolop", expl_list, ast.Num(is_or))
|
|
|
|
|
expl = self.pop_format_context(expl_template)
|
|
|
|
|
return ast.Name(res_var, ast.Load()), self.explanation_param(expl)
|
|
|
|
|
|