Now dependent on command line option.

This commit is contained in:
Victor Maryama
2019-06-25 19:49:05 +02:00
parent cfbfa53f2b
commit 4db5488ed8
5 changed files with 112 additions and 38 deletions

View File

@@ -24,6 +24,14 @@ def pytest_addoption(parser):
expression information.""",
)
group = parser.getgroup("experimental")
group.addoption(
"--enable-assertion-pass-hook",
action="store_true",
help="Enables the pytest_assertion_pass hook."
"Make sure to delete any previously generated pyc cache files.",
)
def register_assert_rewrite(*names):
"""Register one or more module names to be rewritten on import.

View File

@@ -745,7 +745,8 @@ class AssertionRewriter(ast.NodeVisitor):
format_dict = ast.Dict(keys, list(current.values()))
form = ast.BinOp(expl_expr, ast.Mod(), format_dict)
name = "@py_format" + str(next(self.variable_counter))
self.format_variables.append(name)
if getattr(self.config._ns, "enable_assertion_pass_hook", False):
self.format_variables.append(name)
self.expl_stmts.append(ast.Assign([ast.Name(name, ast.Store())], form))
return ast.Name(name, ast.Load())
@@ -780,7 +781,10 @@ class AssertionRewriter(ast.NodeVisitor):
self.statements = []
self.variables = []
self.variable_counter = itertools.count()
self.format_variables = []
if getattr(self.config._ns, "enable_assertion_pass_hook", False):
self.format_variables = []
self.stack = []
self.expl_stmts = []
self.push_format_context()
@@ -793,41 +797,68 @@ class AssertionRewriter(ast.NodeVisitor):
top_condition, module_path=self.module_path, lineno=assert_.lineno
)
)
negation = ast.UnaryOp(ast.Not(), top_condition)
msg = self.pop_format_context(ast.Str(explanation))
if assert_.msg:
assertmsg = self.helper("_format_assertmsg", assert_.msg)
gluestr = "\n>assert "
else:
assertmsg = ast.Str("")
gluestr = "assert "
err_explanation = ast.BinOp(ast.Str(gluestr), ast.Add(), msg)
err_msg = ast.BinOp(assertmsg, ast.Add(), err_explanation)
err_name = ast.Name("AssertionError", ast.Load())
fmt = self.helper("_format_explanation", err_msg)
fmt_pass = self.helper("_format_explanation", msg)
exc = ast.Call(err_name, [fmt], [])
raise_ = ast.Raise(exc, None)
# Call to hook when passes
orig = astor.to_source(assert_.test).rstrip("\n").lstrip("(").rstrip(")")
hook_call_pass = ast.Expr(
self.helper(
"_call_assertion_pass", ast.Num(assert_.lineno), ast.Str(orig), fmt_pass
if getattr(self.config._ns, "enable_assertion_pass_hook", False):
### Experimental pytest_assertion_pass hook
negation = ast.UnaryOp(ast.Not(), top_condition)
msg = self.pop_format_context(ast.Str(explanation))
if assert_.msg:
assertmsg = self.helper("_format_assertmsg", assert_.msg)
gluestr = "\n>assert "
else:
assertmsg = ast.Str("")
gluestr = "assert "
err_explanation = ast.BinOp(ast.Str(gluestr), ast.Add(), msg)
err_msg = ast.BinOp(assertmsg, ast.Add(), err_explanation)
err_name = ast.Name("AssertionError", ast.Load())
fmt = self.helper("_format_explanation", err_msg)
fmt_pass = self.helper("_format_explanation", msg)
exc = ast.Call(err_name, [fmt], [])
raise_ = ast.Raise(exc, None)
# Call to hook when passes
orig = astor.to_source(assert_.test).rstrip("\n").lstrip("(").rstrip(")")
hook_call_pass = ast.Expr(
self.helper(
"_call_assertion_pass",
ast.Num(assert_.lineno),
ast.Str(orig),
fmt_pass,
)
)
)
# If any hooks implement assert_pass hook
hook_impl_test = ast.If(
self.helper("_check_if_assertionpass_impl"), [hook_call_pass], []
)
main_test = ast.If(negation, [raise_], [hook_impl_test])
# If any hooks implement assert_pass hook
hook_impl_test = ast.If(
self.helper("_check_if_assertionpass_impl"), [hook_call_pass], []
)
main_test = ast.If(negation, [raise_], [hook_impl_test])
self.statements.extend(self.expl_stmts)
self.statements.append(main_test)
if self.format_variables:
variables = [ast.Name(name, ast.Store()) for name in self.format_variables]
clear_format = ast.Assign(variables, _NameConstant(None))
self.statements.append(clear_format)
self.statements.extend(self.expl_stmts)
self.statements.append(main_test)
if self.format_variables:
variables = [
ast.Name(name, ast.Store()) for name in self.format_variables
]
clear_format = ast.Assign(variables, _NameConstant(None))
self.statements.append(clear_format)
else:
### Original assertion rewriting
# Create failure message.
body = self.expl_stmts
negation = ast.UnaryOp(ast.Not(), top_condition)
self.statements.append(ast.If(negation, body, []))
if assert_.msg:
assertmsg = self.helper("_format_assertmsg", assert_.msg)
explanation = "\n>assert " + explanation
else:
assertmsg = ast.Str("")
explanation = "assert " + explanation
template = ast.BinOp(assertmsg, ast.Add(), ast.Str(explanation))
msg = self.pop_format_context(template)
fmt = self.helper("_format_explanation", msg)
err_name = ast.Name("AssertionError", ast.Load())
exc = ast.Call(err_name, [fmt], [])
raise_ = ast.Raise(exc, None)
body.append(raise_)
# Clear temporary variables by setting them to None.
if self.variables:
variables = [ast.Name(name, ast.Store()) for name in self.variables]