diff --git a/CHANGELOG b/CHANGELOG index 8fe13a643..4dfccdb37 100644 --- a/CHANGELOG +++ b/CHANGELOG @@ -9,6 +9,10 @@ NEXT other builds due to the extra argparse dependency. Fixes issue566. Thanks sontek. +- Implement issue549: user-provided assertion messages now no longer + replace the py.test instrospection message but are shown in addition + to them. + 2.6.1 ----------------------------------- diff --git a/_pytest/assertion/rewrite.py b/_pytest/assertion/rewrite.py index 523d2b2dc..d552915ca 100644 --- a/_pytest/assertion/rewrite.py +++ b/_pytest/assertion/rewrite.py @@ -329,6 +329,33 @@ def rewrite_asserts(mod): _saferepr = py.io.saferepr from _pytest.assertion.util import format_explanation as _format_explanation # noqa +def _format_assertmsg(obj): + """Format the custom assertion message given. + + For strings this simply replaces newlines with '\n~' so that + util.format_explanation() will preserve them instead of escaping + newlines. For other objects py.io.saferepr() is used first. + + """ + # reprlib appears to have a bug which means that if a string + # contains a newline it gets escaped, however if an object has a + # .__repr__() which contains newlines it does not get escaped. + # However in either case we want to preserve the newline. + if py.builtin._istext(obj) or py.builtin._isbytes(obj): + s = obj + is_repr = False + else: + s = py.io.saferepr(obj) + is_repr = True + if py.builtin._istext(s): + t = py.builtin.text + else: + t = py.builtin.bytes + s = s.replace(t("\n"), t("\n~")) + if is_repr: + s = s.replace(t("\\n"), t("\n~")) + return s + def _should_repr_global_name(obj): return not hasattr(obj, "__name__") and not py.builtin.callable(obj) @@ -397,6 +424,56 @@ def set_location(node, lineno, col_offset): class AssertionRewriter(ast.NodeVisitor): + """Assertion rewriting implementation. + + The main entrypoint is to call .run() with an ast.Module instance, + this will then find all the assert statements and re-write them to + provide intermediate values and a detailed assertion error. See + http://pybites.blogspot.be/2011/07/behind-scenes-of-pytests-new-assertion.html + for an overview of how this works. + + The entry point here is .run() which will iterate over all the + statenemts in an ast.Module and for each ast.Assert statement it + finds call .visit() with it. Then .visit_Assert() takes over and + is responsible for creating new ast statements to replace the + original assert statement: it re-writes the test of an assertion + to provide intermediate values and replace it with an if statement + which raises an assertion error with a detailed explanation in + case the expression is false. + + For this .visit_Assert() uses the visitor pattern to visit all the + AST nodes of the ast.Assert.test field, each visit call returning + an AST node and the corresponding explanation string. During this + state is kept in several instance attributes: + + :statements: All the AST statements which will replace the assert + statement. + + :variables: This is populated by .variable() with each variable + used by the statements so that they can all be set to None at + the end of the statements. + + :variable_counter: Counter to create new unique variables needed + by statements. Variables are created using .variable() and + have the form of "@py_assert0". + + :on_failure: The AST statements which will be executed if the + assertion test fails. This is the code which will construct + the failure message and raises the AssertionError. + + :explanation_specifiers: A dict filled by .explanation_param() + with %-formatting placeholders and their corresponding + expressions to use in the building of an assertion message. + This is used by .pop_format_context() to build a message. + + :stack: A stack of the explanation_specifiers dicts maintained by + .push_format_context() and .pop_format_context() which allows + to build another %-formatted string while already building one. + + This state is reset on every new assert statement visited and used + by the other visitors. + + """ def run(self, mod): """Find all assert statements in *mod* and rewrite them.""" @@ -478,15 +555,41 @@ class AssertionRewriter(ast.NodeVisitor): return ast.Attribute(builtin_name, name, ast.Load()) def explanation_param(self, expr): + """Return a new named %-formatting placeholder for expr. + + This creates a %-formatting placeholder for expr in the + current formatting context, e.g. ``%(py0)s``. The placeholder + and expr are placed in the current format context so that it + can be used on the next call to .pop_format_context(). + + """ specifier = "py" + str(next(self.variable_counter)) self.explanation_specifiers[specifier] = expr return "%(" + specifier + ")s" def push_format_context(self): + """Create a new formatting context. + + The format context is used for when an explanation wants to + have a variable value formatted in the assertion message. In + this case the value required can be added using + .explanation_param(). Finally .pop_format_context() is used + to format a string of %-formatted values as added by + .explanation_param(). + + """ self.explanation_specifiers = {} self.stack.append(self.explanation_specifiers) def pop_format_context(self, expl_expr): + """Format the %-formatted string with current format context. + + The expl_expr should be an ast.Str instance constructed from + the %-placeholders created by .explanation_param(). This will + add the required code to format said string to .on_failure and + return the ast.Name instance of the formatted string. + + """ current = self.stack.pop() if self.stack: self.explanation_specifiers = self.stack[-1] @@ -504,11 +607,15 @@ class AssertionRewriter(ast.NodeVisitor): return res, self.explanation_param(self.display(res)) def visit_Assert(self, assert_): - if assert_.msg: - # There's already a message. Don't mess with it. - return [assert_] + """Return the AST statements to replace the ast.Assert instance. + + This re-writes the test of an assertion to provide + intermediate values and replace it with an if statement which + raises an assertion error with a detailed explanation in case + the expression is false. + + """ self.statements = [] - self.cond_chain = () self.variables = [] self.variable_counter = itertools.count() self.stack = [] @@ -520,8 +627,13 @@ class AssertionRewriter(ast.NodeVisitor): body = self.on_failure negation = ast.UnaryOp(ast.Not(), top_condition) self.statements.append(ast.If(negation, body, [])) - explanation = "assert " + explanation - template = ast.Str(explanation) + if assert_.msg: + assertmsg = self.helper('format_assertmsg', assert_.msg) + explanation = "\n>assert " + explanation + else: + assertmsg = ast.Str("") + explanation = "assert " + explanation + template = ast.BinOp(assertmsg, ast.Add(), ast.Str(explanation)) msg = self.pop_format_context(template) fmt = self.helper("format_explanation", msg) err_name = ast.Name("AssertionError", ast.Load()) diff --git a/_pytest/assertion/util.py b/_pytest/assertion/util.py index 95d308518..f3810fec6 100644 --- a/_pytest/assertion/util.py +++ b/_pytest/assertion/util.py @@ -73,7 +73,7 @@ def _split_explanation(explanation): raw_lines = (explanation or u('')).split('\n') lines = [raw_lines[0]] for l in raw_lines[1:]: - if l.startswith('{') or l.startswith('}') or l.startswith('~'): + if l and l[0] in ['{', '}', '~', '>']: lines.append(l) else: lines[-1] += '\\n' + l @@ -103,13 +103,14 @@ def _format_lines(lines): stackcnt.append(0) result.append(u(' +') + u(' ')*(len(stack)-1) + s + line[1:]) elif line.startswith('}'): - assert line.startswith('}') stack.pop() stackcnt.pop() result[stack[-1]] += line[1:] else: - assert line.startswith('~') - result.append(u(' ')*len(stack) + line[1:]) + assert line[0] in ['~', '>'] + stack[-1] += 1 + indent = len(stack) if line.startswith('~') else len(stack) - 1 + result.append(u(' ')*indent + line[1:]) assert len(stack) == 1 return result diff --git a/doc/en/example/assertion/failure_demo.py b/doc/en/example/assertion/failure_demo.py index ed2776ee1..ecc1cd356 100644 --- a/doc/en/example/assertion/failure_demo.py +++ b/doc/en/example/assertion/failure_demo.py @@ -211,3 +211,27 @@ class TestMoreErrors: finally: x = 0 + +class TestCustomAssertMsg: + + def test_single_line(self): + class A: + a = 1 + b = 2 + assert A.a == b, "A.a appears not to be b" + + def test_multiline(self): + class A: + a = 1 + b = 2 + assert A.a == b, "A.a appears not to be b\n" \ + "or does not appear to be b\none of those" + + def test_custom_repr(self): + class JSON: + a = 1 + def __repr__(self): + return "This is JSON\n{\n 'foo': 'bar'\n}" + a = JSON() + b = 2 + assert a.a == b, a diff --git a/testing/test_assertion.py b/testing/test_assertion.py index 5b873cdc2..3cad8bb60 100644 --- a/testing/test_assertion.py +++ b/testing/test_assertion.py @@ -4,6 +4,7 @@ import sys import py, pytest import _pytest.assertion as plugin from _pytest.assertion import reinterpret +from _pytest.assertion import util needsnewassert = pytest.mark.skipif("sys.version_info < (2,6)") @@ -201,7 +202,7 @@ class TestAssert_reprcompare: class TestFormatExplanation: - def test_speical_chars_full(self, testdir): + def test_special_chars_full(self, testdir): # Issue 453, for the bug this would raise IndexError testdir.makepyfile(""" def test_foo(): @@ -213,6 +214,83 @@ class TestFormatExplanation: "*AssertionError*", ]) + def test_fmt_simple(self): + expl = 'assert foo' + assert util.format_explanation(expl) == 'assert foo' + + def test_fmt_where(self): + expl = '\n'.join(['assert 1', + '{1 = foo', + '} == 2']) + res = '\n'.join(['assert 1 == 2', + ' + where 1 = foo']) + assert util.format_explanation(expl) == res + + def test_fmt_and(self): + expl = '\n'.join(['assert 1', + '{1 = foo', + '} == 2', + '{2 = bar', + '}']) + res = '\n'.join(['assert 1 == 2', + ' + where 1 = foo', + ' + and 2 = bar']) + assert util.format_explanation(expl) == res + + def test_fmt_where_nested(self): + expl = '\n'.join(['assert 1', + '{1 = foo', + '{foo = bar', + '}', + '} == 2']) + res = '\n'.join(['assert 1 == 2', + ' + where 1 = foo', + ' + where foo = bar']) + assert util.format_explanation(expl) == res + + def test_fmt_newline(self): + expl = '\n'.join(['assert "foo" == "bar"', + '~- foo', + '~+ bar']) + res = '\n'.join(['assert "foo" == "bar"', + ' - foo', + ' + bar']) + assert util.format_explanation(expl) == res + + def test_fmt_newline_escaped(self): + expl = '\n'.join(['assert foo == bar', + 'baz']) + res = 'assert foo == bar\\nbaz' + assert util.format_explanation(expl) == res + + def test_fmt_newline_before_where(self): + expl = '\n'.join(['the assertion message here', + '>assert 1', + '{1 = foo', + '} == 2', + '{2 = bar', + '}']) + res = '\n'.join(['the assertion message here', + 'assert 1 == 2', + ' + where 1 = foo', + ' + and 2 = bar']) + assert util.format_explanation(expl) == res + + def test_fmt_multi_newline_before_where(self): + expl = '\n'.join(['the assertion', + '~message here', + '>assert 1', + '{1 = foo', + '} == 2', + '{2 = bar', + '}']) + res = '\n'.join(['the assertion', + ' message here', + 'assert 1 == 2', + ' + where 1 = foo', + ' + and 2 = bar']) + assert util.format_explanation(expl) == res + def test_python25_compile_issue257(testdir): testdir.makepyfile(""" diff --git a/testing/test_assertrewrite.py b/testing/test_assertrewrite.py index 72e2d4e10..6c93bdb31 100644 --- a/testing/test_assertrewrite.py +++ b/testing/test_assertrewrite.py @@ -121,7 +121,56 @@ class TestAssertionRewrite: def test_assert_already_has_message(self): def f(): assert False, "something bad!" - assert getmsg(f) == "AssertionError: something bad!" + assert getmsg(f) == "AssertionError: something bad!\nassert False" + + def test_assertion_message(self, testdir): + testdir.makepyfile(""" + def test_foo(): + assert 1 == 2, "The failure message" + """) + result = testdir.runpytest() + assert result.ret == 1 + result.stdout.fnmatch_lines([ + "*AssertionError*The failure message*", + "*assert 1 == 2*", + ]) + + def test_assertion_message_multiline(self, testdir): + testdir.makepyfile(""" + def test_foo(): + assert 1 == 2, "A multiline\\nfailure message" + """) + result = testdir.runpytest() + assert result.ret == 1 + result.stdout.fnmatch_lines([ + "*AssertionError*A multiline*", + "*failure message*", + "*assert 1 == 2*", + ]) + + def test_assertion_message_tuple(self, testdir): + testdir.makepyfile(""" + def test_foo(): + assert 1 == 2, (1, 2) + """) + result = testdir.runpytest() + assert result.ret == 1 + result.stdout.fnmatch_lines([ + "*AssertionError*%s*" % repr((1, 2)), + "*assert 1 == 2*", + ]) + + def test_assertion_message_expr(self, testdir): + testdir.makepyfile(""" + def test_foo(): + assert 1 == 2, 1 + 2 + """) + result = testdir.runpytest() + assert result.ret == 1 + result.stdout.fnmatch_lines([ + "*AssertionError*3*", + "*assert 1 == 2*", + ]) def test_boolop(self): def f():