diff --git a/_pytest/assertion/__init__.py b/_pytest/assertion/__init__.py index 26fed96b4..9f89d17fc 100644 --- a/_pytest/assertion/__init__.py +++ b/_pytest/assertion/__init__.py @@ -31,15 +31,16 @@ def pytest_configure(config): config._cleanup.append(m.undo) warn_about_missing_assertion() if not config.getvalue("noassert") and not config.getvalue("nomagic"): + from _pytest.assertion import reinterpret def callbinrepr(op, left, right): hook_result = config.hook.pytest_assertrepr_compare( config=config, op=op, left=left, right=right) for new_expl in hook_result: if new_expl: return '\n~'.join(new_expl) - m.setattr(py.builtin.builtins, - 'AssertionError', py.code._AssertionError) - m.setattr(py.code, '_reprcompare', callbinrepr) + m.setattr(py.builtin.builtins, 'AssertionError', + reinterpret.AssertionError) + m.setattr(sys.modules[__name__], '_reprcompare', callbinrepr) else: rewrite_asserts = None @@ -98,6 +99,53 @@ def warn_about_missing_assertion(): sys.stderr.write("WARNING: failing tests may report as passing because " "assertions are turned off! (are you using python -O?)\n") +# if set, will be called by assert reinterp for comparison ops +_reprcompare = None + +def _format_explanation(explanation): + """This formats an explanation + + Normally all embedded newlines are escaped, however there are + three exceptions: \n{, \n} and \n~. The first two are intended + cover nested explanations, see function and attribute explanations + for examples (.visit_Call(), visit_Attribute()). The last one is + for when one explanation needs to span multiple lines, e.g. when + displaying diffs. + """ + raw_lines = (explanation or '').split('\n') + # escape newlines not followed by {, } and ~ + lines = [raw_lines[0]] + for l in raw_lines[1:]: + if l.startswith('{') or l.startswith('}') or l.startswith('~'): + lines.append(l) + else: + lines[-1] += '\\n' + l + + result = lines[:1] + stack = [0] + stackcnt = [0] + for line in lines[1:]: + if line.startswith('{'): + if stackcnt[-1]: + s = 'and ' + else: + s = 'where ' + stack.append(len(result)) + stackcnt[-1] += 1 + stackcnt.append(0) + result.append(' +' + ' '*(len(stack)-1) + s + line[1:]) + elif line.startswith('}'): + assert line.startswith('}') + stack.pop() + stackcnt.pop() + result[stack[-1]] += line[1:] + else: + assert line.startswith('~') + result.append(' '*len(stack) + line[1:]) + assert len(stack) == 1 + return '\n'.join(result) + + # Provide basestring in python3 try: basestring = basestring diff --git a/_pytest/assertion/newinterpret.py b/_pytest/assertion/newinterpret.py new file mode 100644 index 000000000..1d061aa46 --- /dev/null +++ b/_pytest/assertion/newinterpret.py @@ -0,0 +1,340 @@ +""" +Find intermediate evalutation results in assert statements through builtin AST. +This should replace oldinterpret.py eventually. +""" + +import sys +import ast + +import py +from _pytest import assertion +from _pytest.assertion import _format_explanation +from _pytest.assertion.reinterpret import BuiltinAssertionError + + +if sys.platform.startswith("java") and sys.version_info < (2, 5, 2): + # See http://bugs.jython.org/issue1497 + _exprs = ("BoolOp", "BinOp", "UnaryOp", "Lambda", "IfExp", "Dict", + "ListComp", "GeneratorExp", "Yield", "Compare", "Call", + "Repr", "Num", "Str", "Attribute", "Subscript", "Name", + "List", "Tuple") + _stmts = ("FunctionDef", "ClassDef", "Return", "Delete", "Assign", + "AugAssign", "Print", "For", "While", "If", "With", "Raise", + "TryExcept", "TryFinally", "Assert", "Import", "ImportFrom", + "Exec", "Global", "Expr", "Pass", "Break", "Continue") + _expr_nodes = set(getattr(ast, name) for name in _exprs) + _stmt_nodes = set(getattr(ast, name) for name in _stmts) + def _is_ast_expr(node): + return node.__class__ in _expr_nodes + def _is_ast_stmt(node): + return node.__class__ in _stmt_nodes +else: + def _is_ast_expr(node): + return isinstance(node, ast.expr) + def _is_ast_stmt(node): + return isinstance(node, ast.stmt) + + +class Failure(Exception): + """Error found while interpreting AST.""" + + def __init__(self, explanation=""): + self.cause = sys.exc_info() + self.explanation = explanation + + +def interpret(source, frame, should_fail=False): + mod = ast.parse(source) + visitor = DebugInterpreter(frame) + try: + visitor.visit(mod) + except Failure: + failure = sys.exc_info()[1] + return getfailure(failure) + if should_fail: + return ("(assertion failed, but when it was re-run for " + "printing intermediate values, it did not fail. Suggestions: " + "compute assert expression before the assert or use --no-assert)") + +def run(offending_line, frame=None): + if frame is None: + frame = py.code.Frame(sys._getframe(1)) + return interpret(offending_line, frame) + +def getfailure(failure): + explanation = _format_explanation(failure.explanation) + value = failure.cause[1] + if str(value): + lines = explanation.splitlines() + if not lines: + lines.append("") + lines[0] += " << %s" % (value,) + explanation = "\n".join(lines) + text = "%s: %s" % (failure.cause[0].__name__, explanation) + if text.startswith("AssertionError: assert "): + text = text[16:] + return text + + +operator_map = { + ast.BitOr : "|", + ast.BitXor : "^", + ast.BitAnd : "&", + ast.LShift : "<<", + ast.RShift : ">>", + ast.Add : "+", + ast.Sub : "-", + ast.Mult : "*", + ast.Div : "/", + ast.FloorDiv : "//", + ast.Mod : "%", + ast.Eq : "==", + ast.NotEq : "!=", + ast.Lt : "<", + ast.LtE : "<=", + ast.Gt : ">", + ast.GtE : ">=", + ast.Pow : "**", + ast.Is : "is", + ast.IsNot : "is not", + ast.In : "in", + ast.NotIn : "not in" +} + +unary_map = { + ast.Not : "not %s", + ast.Invert : "~%s", + ast.USub : "-%s", + ast.UAdd : "+%s" +} + + +class DebugInterpreter(ast.NodeVisitor): + """Interpret AST nodes to gleam useful debugging information. """ + + def __init__(self, frame): + self.frame = frame + + def generic_visit(self, node): + # Fallback when we don't have a special implementation. + if _is_ast_expr(node): + mod = ast.Expression(node) + co = self._compile(mod) + try: + result = self.frame.eval(co) + except Exception: + raise Failure() + explanation = self.frame.repr(result) + return explanation, result + elif _is_ast_stmt(node): + mod = ast.Module([node]) + co = self._compile(mod, "exec") + try: + self.frame.exec_(co) + except Exception: + raise Failure() + return None, None + else: + raise AssertionError("can't handle %s" %(node,)) + + def _compile(self, source, mode="eval"): + return compile(source, "", mode) + + def visit_Expr(self, expr): + return self.visit(expr.value) + + def visit_Module(self, mod): + for stmt in mod.body: + self.visit(stmt) + + def visit_Name(self, name): + explanation, result = self.generic_visit(name) + # See if the name is local. + source = "%r in locals() is not globals()" % (name.id,) + co = self._compile(source) + try: + local = self.frame.eval(co) + except Exception: + # have to assume it isn't + local = False + if not local: + return name.id, result + return explanation, result + + def visit_Compare(self, comp): + left = comp.left + left_explanation, left_result = self.visit(left) + for op, next_op in zip(comp.ops, comp.comparators): + next_explanation, next_result = self.visit(next_op) + op_symbol = operator_map[op.__class__] + explanation = "%s %s %s" % (left_explanation, op_symbol, + next_explanation) + source = "__exprinfo_left %s __exprinfo_right" % (op_symbol,) + co = self._compile(source) + try: + result = self.frame.eval(co, __exprinfo_left=left_result, + __exprinfo_right=next_result) + except Exception: + raise Failure(explanation) + try: + if not result: + break + except KeyboardInterrupt: + raise + except: + break + left_explanation, left_result = next_explanation, next_result + + if assertion._reprcompare is not None: + res = assertion._reprcompare(op_symbol, left_result, next_result) + if res: + explanation = res + return explanation, result + + def visit_BoolOp(self, boolop): + is_or = isinstance(boolop.op, ast.Or) + explanations = [] + for operand in boolop.values: + explanation, result = self.visit(operand) + explanations.append(explanation) + if result == is_or: + break + name = is_or and " or " or " and " + explanation = "(" + name.join(explanations) + ")" + return explanation, result + + def visit_UnaryOp(self, unary): + pattern = unary_map[unary.op.__class__] + operand_explanation, operand_result = self.visit(unary.operand) + explanation = pattern % (operand_explanation,) + co = self._compile(pattern % ("__exprinfo_expr",)) + try: + result = self.frame.eval(co, __exprinfo_expr=operand_result) + except Exception: + raise Failure(explanation) + return explanation, result + + def visit_BinOp(self, binop): + left_explanation, left_result = self.visit(binop.left) + right_explanation, right_result = self.visit(binop.right) + symbol = operator_map[binop.op.__class__] + explanation = "(%s %s %s)" % (left_explanation, symbol, + right_explanation) + source = "__exprinfo_left %s __exprinfo_right" % (symbol,) + co = self._compile(source) + try: + result = self.frame.eval(co, __exprinfo_left=left_result, + __exprinfo_right=right_result) + except Exception: + raise Failure(explanation) + return explanation, result + + def visit_Call(self, call): + func_explanation, func = self.visit(call.func) + arg_explanations = [] + ns = {"__exprinfo_func" : func} + arguments = [] + for arg in call.args: + arg_explanation, arg_result = self.visit(arg) + arg_name = "__exprinfo_%s" % (len(ns),) + ns[arg_name] = arg_result + arguments.append(arg_name) + arg_explanations.append(arg_explanation) + for keyword in call.keywords: + arg_explanation, arg_result = self.visit(keyword.value) + arg_name = "__exprinfo_%s" % (len(ns),) + ns[arg_name] = arg_result + keyword_source = "%s=%%s" % (keyword.arg) + arguments.append(keyword_source % (arg_name,)) + arg_explanations.append(keyword_source % (arg_explanation,)) + if call.starargs: + arg_explanation, arg_result = self.visit(call.starargs) + arg_name = "__exprinfo_star" + ns[arg_name] = arg_result + arguments.append("*%s" % (arg_name,)) + arg_explanations.append("*%s" % (arg_explanation,)) + if call.kwargs: + arg_explanation, arg_result = self.visit(call.kwargs) + arg_name = "__exprinfo_kwds" + ns[arg_name] = arg_result + arguments.append("**%s" % (arg_name,)) + arg_explanations.append("**%s" % (arg_explanation,)) + args_explained = ", ".join(arg_explanations) + explanation = "%s(%s)" % (func_explanation, args_explained) + args = ", ".join(arguments) + source = "__exprinfo_func(%s)" % (args,) + co = self._compile(source) + try: + result = self.frame.eval(co, **ns) + except Exception: + raise Failure(explanation) + pattern = "%s\n{%s = %s\n}" + rep = self.frame.repr(result) + explanation = pattern % (rep, rep, explanation) + return explanation, result + + def _is_builtin_name(self, name): + pattern = "%r not in globals() and %r not in locals()" + source = pattern % (name.id, name.id) + co = self._compile(source) + try: + return self.frame.eval(co) + except Exception: + return False + + def visit_Attribute(self, attr): + if not isinstance(attr.ctx, ast.Load): + return self.generic_visit(attr) + source_explanation, source_result = self.visit(attr.value) + explanation = "%s.%s" % (source_explanation, attr.attr) + source = "__exprinfo_expr.%s" % (attr.attr,) + co = self._compile(source) + try: + result = self.frame.eval(co, __exprinfo_expr=source_result) + except Exception: + raise Failure(explanation) + explanation = "%s\n{%s = %s.%s\n}" % (self.frame.repr(result), + self.frame.repr(result), + source_explanation, attr.attr) + # Check if the attr is from an instance. + source = "%r in getattr(__exprinfo_expr, '__dict__', {})" + source = source % (attr.attr,) + co = self._compile(source) + try: + from_instance = self.frame.eval(co, __exprinfo_expr=source_result) + except Exception: + from_instance = True + if from_instance: + rep = self.frame.repr(result) + pattern = "%s\n{%s = %s\n}" + explanation = pattern % (rep, rep, explanation) + return explanation, result + + def visit_Assert(self, assrt): + test_explanation, test_result = self.visit(assrt.test) + if test_explanation.startswith("False\n{False =") and \ + test_explanation.endswith("\n"): + test_explanation = test_explanation[15:-2] + explanation = "assert %s" % (test_explanation,) + if not test_result: + try: + raise BuiltinAssertionError + except Exception: + raise Failure(explanation) + return explanation, test_result + + def visit_Assign(self, assign): + value_explanation, value_result = self.visit(assign.value) + explanation = "... = %s" % (value_explanation,) + name = ast.Name("__exprinfo_expr", ast.Load(), + lineno=assign.value.lineno, + col_offset=assign.value.col_offset) + new_assign = ast.Assign(assign.targets, name, lineno=assign.lineno, + col_offset=assign.col_offset) + mod = ast.Module([new_assign]) + co = self._compile(mod, "exec") + try: + self.frame.exec_(co, __exprinfo_expr=value_result) + except Exception: + raise Failure(explanation) + return explanation, value_result diff --git a/_pytest/assertion/oldinterpret.py b/_pytest/assertion/oldinterpret.py new file mode 100644 index 000000000..3e8f1c0b3 --- /dev/null +++ b/_pytest/assertion/oldinterpret.py @@ -0,0 +1,556 @@ +import py +import sys, inspect +from compiler import parse, ast, pycodegen +from _pytest.assertion import _format_explanation +from _pytest.assertion.reinterpret import BuiltinAssertionError + +passthroughex = py.builtin._sysex + +class Failure: + def __init__(self, node): + self.exc, self.value, self.tb = sys.exc_info() + self.node = node + +class View(object): + """View base class. + + If C is a subclass of View, then C(x) creates a proxy object around + the object x. The actual class of the proxy is not C in general, + but a *subclass* of C determined by the rules below. To avoid confusion + we call view class the class of the proxy (a subclass of C, so of View) + and object class the class of x. + + Attributes and methods not found in the proxy are automatically read on x. + Other operations like setting attributes are performed on the proxy, as + determined by its view class. The object x is available from the proxy + as its __obj__ attribute. + + The view class selection is determined by the __view__ tuples and the + optional __viewkey__ method. By default, the selected view class is the + most specific subclass of C whose __view__ mentions the class of x. + If no such subclass is found, the search proceeds with the parent + object classes. For example, C(True) will first look for a subclass + of C with __view__ = (..., bool, ...) and only if it doesn't find any + look for one with __view__ = (..., int, ...), and then ..., object,... + If everything fails the class C itself is considered to be the default. + + Alternatively, the view class selection can be driven by another aspect + of the object x, instead of the class of x, by overriding __viewkey__. + See last example at the end of this module. + """ + + _viewcache = {} + __view__ = () + + def __new__(rootclass, obj, *args, **kwds): + self = object.__new__(rootclass) + self.__obj__ = obj + self.__rootclass__ = rootclass + key = self.__viewkey__() + try: + self.__class__ = self._viewcache[key] + except KeyError: + self.__class__ = self._selectsubclass(key) + return self + + def __getattr__(self, attr): + # attributes not found in the normal hierarchy rooted on View + # are looked up in the object's real class + return getattr(self.__obj__, attr) + + def __viewkey__(self): + return self.__obj__.__class__ + + def __matchkey__(self, key, subclasses): + if inspect.isclass(key): + keys = inspect.getmro(key) + else: + keys = [key] + for key in keys: + result = [C for C in subclasses if key in C.__view__] + if result: + return result + return [] + + def _selectsubclass(self, key): + subclasses = list(enumsubclasses(self.__rootclass__)) + for C in subclasses: + if not isinstance(C.__view__, tuple): + C.__view__ = (C.__view__,) + choices = self.__matchkey__(key, subclasses) + if not choices: + return self.__rootclass__ + elif len(choices) == 1: + return choices[0] + else: + # combine the multiple choices + return type('?', tuple(choices), {}) + + def __repr__(self): + return '%s(%r)' % (self.__rootclass__.__name__, self.__obj__) + + +def enumsubclasses(cls): + for subcls in cls.__subclasses__(): + for subsubclass in enumsubclasses(subcls): + yield subsubclass + yield cls + + +class Interpretable(View): + """A parse tree node with a few extra methods.""" + explanation = None + + def is_builtin(self, frame): + return False + + def eval(self, frame): + # fall-back for unknown expression nodes + try: + expr = ast.Expression(self.__obj__) + expr.filename = '' + self.__obj__.filename = '' + co = pycodegen.ExpressionCodeGenerator(expr).getCode() + result = frame.eval(co) + except passthroughex: + raise + except: + raise Failure(self) + self.result = result + self.explanation = self.explanation or frame.repr(self.result) + + def run(self, frame): + # fall-back for unknown statement nodes + try: + expr = ast.Module(None, ast.Stmt([self.__obj__])) + expr.filename = '' + co = pycodegen.ModuleCodeGenerator(expr).getCode() + frame.exec_(co) + except passthroughex: + raise + except: + raise Failure(self) + + def nice_explanation(self): + return _format_explanation(self.explanation) + + +class Name(Interpretable): + __view__ = ast.Name + + def is_local(self, frame): + source = '%r in locals() is not globals()' % self.name + try: + return frame.is_true(frame.eval(source)) + except passthroughex: + raise + except: + return False + + def is_global(self, frame): + source = '%r in globals()' % self.name + try: + return frame.is_true(frame.eval(source)) + except passthroughex: + raise + except: + return False + + def is_builtin(self, frame): + source = '%r not in locals() and %r not in globals()' % ( + self.name, self.name) + try: + return frame.is_true(frame.eval(source)) + except passthroughex: + raise + except: + return False + + def eval(self, frame): + super(Name, self).eval(frame) + if not self.is_local(frame): + self.explanation = self.name + +class Compare(Interpretable): + __view__ = ast.Compare + + def eval(self, frame): + expr = Interpretable(self.expr) + expr.eval(frame) + for operation, expr2 in self.ops: + if hasattr(self, 'result'): + # shortcutting in chained expressions + if not frame.is_true(self.result): + break + expr2 = Interpretable(expr2) + expr2.eval(frame) + self.explanation = "%s %s %s" % ( + expr.explanation, operation, expr2.explanation) + source = "__exprinfo_left %s __exprinfo_right" % operation + try: + self.result = frame.eval(source, + __exprinfo_left=expr.result, + __exprinfo_right=expr2.result) + except passthroughex: + raise + except: + raise Failure(self) + expr = expr2 + +class And(Interpretable): + __view__ = ast.And + + def eval(self, frame): + explanations = [] + for expr in self.nodes: + expr = Interpretable(expr) + expr.eval(frame) + explanations.append(expr.explanation) + self.result = expr.result + if not frame.is_true(expr.result): + break + self.explanation = '(' + ' and '.join(explanations) + ')' + +class Or(Interpretable): + __view__ = ast.Or + + def eval(self, frame): + explanations = [] + for expr in self.nodes: + expr = Interpretable(expr) + expr.eval(frame) + explanations.append(expr.explanation) + self.result = expr.result + if frame.is_true(expr.result): + break + self.explanation = '(' + ' or '.join(explanations) + ')' + + +# == Unary operations == +keepalive = [] +for astclass, astpattern in { + ast.Not : 'not __exprinfo_expr', + ast.Invert : '(~__exprinfo_expr)', + }.items(): + + class UnaryArith(Interpretable): + __view__ = astclass + + def eval(self, frame, astpattern=astpattern): + expr = Interpretable(self.expr) + expr.eval(frame) + self.explanation = astpattern.replace('__exprinfo_expr', + expr.explanation) + try: + self.result = frame.eval(astpattern, + __exprinfo_expr=expr.result) + except passthroughex: + raise + except: + raise Failure(self) + + keepalive.append(UnaryArith) + +# == Binary operations == +for astclass, astpattern in { + ast.Add : '(__exprinfo_left + __exprinfo_right)', + ast.Sub : '(__exprinfo_left - __exprinfo_right)', + ast.Mul : '(__exprinfo_left * __exprinfo_right)', + ast.Div : '(__exprinfo_left / __exprinfo_right)', + ast.Mod : '(__exprinfo_left % __exprinfo_right)', + ast.Power : '(__exprinfo_left ** __exprinfo_right)', + }.items(): + + class BinaryArith(Interpretable): + __view__ = astclass + + def eval(self, frame, astpattern=astpattern): + left = Interpretable(self.left) + left.eval(frame) + right = Interpretable(self.right) + right.eval(frame) + self.explanation = (astpattern + .replace('__exprinfo_left', left .explanation) + .replace('__exprinfo_right', right.explanation)) + try: + self.result = frame.eval(astpattern, + __exprinfo_left=left.result, + __exprinfo_right=right.result) + except passthroughex: + raise + except: + raise Failure(self) + + keepalive.append(BinaryArith) + + +class CallFunc(Interpretable): + __view__ = ast.CallFunc + + def is_bool(self, frame): + source = 'isinstance(__exprinfo_value, bool)' + try: + return frame.is_true(frame.eval(source, + __exprinfo_value=self.result)) + except passthroughex: + raise + except: + return False + + def eval(self, frame): + node = Interpretable(self.node) + node.eval(frame) + explanations = [] + vars = {'__exprinfo_fn': node.result} + source = '__exprinfo_fn(' + for a in self.args: + if isinstance(a, ast.Keyword): + keyword = a.name + a = a.expr + else: + keyword = None + a = Interpretable(a) + a.eval(frame) + argname = '__exprinfo_%d' % len(vars) + vars[argname] = a.result + if keyword is None: + source += argname + ',' + explanations.append(a.explanation) + else: + source += '%s=%s,' % (keyword, argname) + explanations.append('%s=%s' % (keyword, a.explanation)) + if self.star_args: + star_args = Interpretable(self.star_args) + star_args.eval(frame) + argname = '__exprinfo_star' + vars[argname] = star_args.result + source += '*' + argname + ',' + explanations.append('*' + star_args.explanation) + if self.dstar_args: + dstar_args = Interpretable(self.dstar_args) + dstar_args.eval(frame) + argname = '__exprinfo_kwds' + vars[argname] = dstar_args.result + source += '**' + argname + ',' + explanations.append('**' + dstar_args.explanation) + self.explanation = "%s(%s)" % ( + node.explanation, ', '.join(explanations)) + if source.endswith(','): + source = source[:-1] + source += ')' + try: + self.result = frame.eval(source, **vars) + except passthroughex: + raise + except: + raise Failure(self) + if not node.is_builtin(frame) or not self.is_bool(frame): + r = frame.repr(self.result) + self.explanation = '%s\n{%s = %s\n}' % (r, r, self.explanation) + +class Getattr(Interpretable): + __view__ = ast.Getattr + + def eval(self, frame): + expr = Interpretable(self.expr) + expr.eval(frame) + source = '__exprinfo_expr.%s' % self.attrname + try: + self.result = frame.eval(source, __exprinfo_expr=expr.result) + except passthroughex: + raise + except: + raise Failure(self) + self.explanation = '%s.%s' % (expr.explanation, self.attrname) + # if the attribute comes from the instance, its value is interesting + source = ('hasattr(__exprinfo_expr, "__dict__") and ' + '%r in __exprinfo_expr.__dict__' % self.attrname) + try: + from_instance = frame.is_true( + frame.eval(source, __exprinfo_expr=expr.result)) + except passthroughex: + raise + except: + from_instance = True + if from_instance: + r = frame.repr(self.result) + self.explanation = '%s\n{%s = %s\n}' % (r, r, self.explanation) + +# == Re-interpretation of full statements == + +class Assert(Interpretable): + __view__ = ast.Assert + + def run(self, frame): + test = Interpretable(self.test) + test.eval(frame) + # simplify 'assert False where False = ...' + if (test.explanation.startswith('False\n{False = ') and + test.explanation.endswith('\n}')): + test.explanation = test.explanation[15:-2] + # print the result as 'assert ' + self.result = test.result + self.explanation = 'assert ' + test.explanation + if not frame.is_true(test.result): + try: + raise BuiltinAssertionError + except passthroughex: + raise + except: + raise Failure(self) + +class Assign(Interpretable): + __view__ = ast.Assign + + def run(self, frame): + expr = Interpretable(self.expr) + expr.eval(frame) + self.result = expr.result + self.explanation = '... = ' + expr.explanation + # fall-back-run the rest of the assignment + ass = ast.Assign(self.nodes, ast.Name('__exprinfo_expr')) + mod = ast.Module(None, ast.Stmt([ass])) + mod.filename = '' + co = pycodegen.ModuleCodeGenerator(mod).getCode() + try: + frame.exec_(co, __exprinfo_expr=expr.result) + except passthroughex: + raise + except: + raise Failure(self) + +class Discard(Interpretable): + __view__ = ast.Discard + + def run(self, frame): + expr = Interpretable(self.expr) + expr.eval(frame) + self.result = expr.result + self.explanation = expr.explanation + +class Stmt(Interpretable): + __view__ = ast.Stmt + + def run(self, frame): + for stmt in self.nodes: + stmt = Interpretable(stmt) + stmt.run(frame) + + +def report_failure(e): + explanation = e.node.nice_explanation() + if explanation: + explanation = ", in: " + explanation + else: + explanation = "" + sys.stdout.write("%s: %s%s\n" % (e.exc.__name__, e.value, explanation)) + +def check(s, frame=None): + if frame is None: + frame = sys._getframe(1) + frame = py.code.Frame(frame) + expr = parse(s, 'eval') + assert isinstance(expr, ast.Expression) + node = Interpretable(expr.node) + try: + node.eval(frame) + except passthroughex: + raise + except Failure: + e = sys.exc_info()[1] + report_failure(e) + else: + if not frame.is_true(node.result): + sys.stderr.write("assertion failed: %s\n" % node.nice_explanation()) + + +########################################################### +# API / Entry points +# ######################################################### + +def interpret(source, frame, should_fail=False): + module = Interpretable(parse(source, 'exec').node) + #print "got module", module + if isinstance(frame, py.std.types.FrameType): + frame = py.code.Frame(frame) + try: + module.run(frame) + except Failure: + e = sys.exc_info()[1] + return getfailure(e) + except passthroughex: + raise + except: + import traceback + traceback.print_exc() + if should_fail: + return ("(assertion failed, but when it was re-run for " + "printing intermediate values, it did not fail. Suggestions: " + "compute assert expression before the assert or use --nomagic)") + else: + return None + +def getmsg(excinfo): + if isinstance(excinfo, tuple): + excinfo = py.code.ExceptionInfo(excinfo) + #frame, line = gettbline(tb) + #frame = py.code.Frame(frame) + #return interpret(line, frame) + + tb = excinfo.traceback[-1] + source = str(tb.statement).strip() + x = interpret(source, tb.frame, should_fail=True) + if not isinstance(x, str): + raise TypeError("interpret returned non-string %r" % (x,)) + return x + +def getfailure(e): + explanation = e.node.nice_explanation() + if str(e.value): + lines = explanation.split('\n') + lines[0] += " << %s" % (e.value,) + explanation = '\n'.join(lines) + text = "%s: %s" % (e.exc.__name__, explanation) + if text.startswith('AssertionError: assert '): + text = text[16:] + return text + +def run(s, frame=None): + if frame is None: + frame = sys._getframe(1) + frame = py.code.Frame(frame) + module = Interpretable(parse(s, 'exec').node) + try: + module.run(frame) + except Failure: + e = sys.exc_info()[1] + report_failure(e) + + +if __name__ == '__main__': + # example: + def f(): + return 5 + def g(): + return 3 + def h(x): + return 'never' + check("f() * g() == 5") + check("not f()") + check("not (f() and g() or 0)") + check("f() == g()") + i = 4 + check("i == f()") + check("len(f()) == 0") + check("isinstance(2+3+4, float)") + + run("x = i") + check("x == 5") + + run("assert not f(), 'oops'") + run("a, b, c = 1, 2") + run("a, b, c = f()") + + check("max([f(),g()]) == 4") + check("'hello'[g()] == 'h'") + run("'guk%d' % h(f())") diff --git a/_pytest/assertion/reinterpret.py b/_pytest/assertion/reinterpret.py new file mode 100644 index 000000000..6e9465d8a --- /dev/null +++ b/_pytest/assertion/reinterpret.py @@ -0,0 +1,48 @@ +import sys +import py + +BuiltinAssertionError = py.builtin.builtins.AssertionError + +class AssertionError(BuiltinAssertionError): + def __init__(self, *args): + BuiltinAssertionError.__init__(self, *args) + if args: + try: + self.msg = str(args[0]) + except py.builtin._sysex: + raise + except: + self.msg = "<[broken __repr__] %s at %0xd>" %( + args[0].__class__, id(args[0])) + else: + f = py.code.Frame(sys._getframe(1)) + try: + source = f.code.fullsource + if source is not None: + try: + source = source.getstatement(f.lineno, assertion=True) + except IndexError: + source = None + else: + source = str(source.deindent()).strip() + except py.error.ENOENT: + source = None + # this can also occur during reinterpretation, when the + # co_filename is set to "". + if source: + self.msg = reinterpret(source, f, should_fail=True) + else: + self.msg = "" + if not self.args: + self.args = (self.msg,) + +if sys.version_info > (3, 0): + AssertionError.__module__ = "builtins" + reinterpret_old = "old reinterpretation not available for py3" +else: + from _pytest.assertion.oldinterpret import interpret as reinterpret_old +if sys.version_info >= (2, 6) or (sys.platform.startswith("java")): + from _pytest.assertion.newinterpret import interpret as reinterpret +else: + reinterpret = reinterpret_old + diff --git a/_pytest/assertion/rewrite.py b/_pytest/assertion/rewrite.py index 29ce43869..186d2425e 100644 --- a/_pytest/assertion/rewrite.py +++ b/_pytest/assertion/rewrite.py @@ -13,7 +13,6 @@ def rewrite_asserts(mod): _saferepr = py.io.saferepr -_format_explanation = py.code._format_explanation def _format_boolop(operands, explanations, is_or): show_explanations = [] @@ -31,8 +30,9 @@ def _call_reprcompare(ops, results, expls, each_obj): done = True if done: break - if py.code._reprcompare is not None: - custom = py.code._reprcompare(ops[i], each_obj[i], each_obj[i + 1]) + from _pytest.assertion import _reprcompare + if _reprcompare is not None: + custom = _reprcompare(ops[i], each_obj[i], each_obj[i + 1]) if custom is not None: return custom return expl @@ -94,7 +94,7 @@ class AssertionRewriter(ast.NodeVisitor): # Insert some special imports at the top of the module but after any # docstrings and __future__ imports. aliases = [ast.alias(py.builtin.builtins.__name__, "@py_builtins"), - ast.alias("py", "@pylib"), + ast.alias("_pytest.assertion", "@pytest_a"), ast.alias("_pytest.assertion.rewrite", "@pytest_ar")] expect_docstring = True pos = 0 @@ -153,11 +153,11 @@ class AssertionRewriter(ast.NodeVisitor): def display(self, expr): """Call py.io.saferepr on the expression.""" - return self.helper("saferepr", expr) + return self.helper("ar", "saferepr", expr) - def helper(self, name, *args): + def helper(self, mod, name, *args): """Call a helper in this module.""" - py_name = ast.Name("@pytest_ar", ast.Load()) + py_name = ast.Name("@pytest_" + mod, ast.Load()) attr = ast.Attribute(py_name, "_" + name, ast.Load()) return ast.Call(attr, list(args), [], None, None) @@ -211,7 +211,7 @@ class AssertionRewriter(ast.NodeVisitor): explanation = "assert " + explanation template = ast.Str(explanation) msg = self.pop_format_context(template) - fmt = self.helper("format_explanation", msg) + fmt = self.helper("a", "format_explanation", msg) body.append(ast.Assert(top_condition, fmt)) # Delete temporary variables. names = [ast.Name(name, ast.Del()) for name in self.variables] @@ -242,7 +242,7 @@ class AssertionRewriter(ast.NodeVisitor): explanations.append(explanation) expls = ast.Tuple([ast.Str(expl) for expl in explanations], ast.Load()) is_or = ast.Num(isinstance(boolop.op, ast.Or)) - expl_template = self.helper("format_boolop", + expl_template = self.helper("ar", "format_boolop", ast.Tuple(operands, ast.Load()), expls, is_or) expl = self.pop_format_context(expl_template) @@ -321,7 +321,8 @@ class AssertionRewriter(ast.NodeVisitor): self.statements.append(ast.Assign([store_names[i]], res_expr)) left_res, left_expl = next_res, next_expl # Use py.code._reprcompare if that's available. - expl_call = self.helper("call_reprcompare", ast.Tuple(syms, ast.Load()), + expl_call = self.helper("ar", "call_reprcompare", + ast.Tuple(syms, ast.Load()), ast.Tuple(load_names, ast.Load()), ast.Tuple(expls, ast.Load()), ast.Tuple(results, ast.Load())) diff --git a/testing/test_assertinterpret.py b/testing/test_assertinterpret.py new file mode 100644 index 000000000..318516eae --- /dev/null +++ b/testing/test_assertinterpret.py @@ -0,0 +1,327 @@ +"PYTEST_DONT_REWRITE" +import pytest, py + +from _pytest import assertion + +def exvalue(): + return py.std.sys.exc_info()[1] + +def f(): + return 2 + +def test_not_being_rewritten(): + assert "@py_builtins" not in globals() + +def test_assert(): + try: + assert f() == 3 + except AssertionError: + e = exvalue() + s = str(e) + assert s.startswith('assert 2 == 3\n') + +def test_assert_with_explicit_message(): + try: + assert f() == 3, "hello" + except AssertionError: + e = exvalue() + assert e.msg == 'hello' + +def test_assert_within_finally(): + class A: + def f(): + pass + excinfo = py.test.raises(TypeError, """ + try: + A().f() + finally: + i = 42 + """) + s = excinfo.exconly() + assert s.find("takes no argument") != -1 + + #def g(): + # A.f() + #excinfo = getexcinfo(TypeError, g) + #msg = getmsg(excinfo) + #assert msg.find("must be called with A") != -1 + + +def test_assert_multiline_1(): + try: + assert (f() == + 3) + except AssertionError: + e = exvalue() + s = str(e) + assert s.startswith('assert 2 == 3\n') + +def test_assert_multiline_2(): + try: + assert (f() == (4, + 3)[-1]) + except AssertionError: + e = exvalue() + s = str(e) + assert s.startswith('assert 2 ==') + +def test_in(): + try: + assert "hi" in [1, 2] + except AssertionError: + e = exvalue() + s = str(e) + assert s.startswith("assert 'hi' in") + +def test_is(): + try: + assert 1 is 2 + except AssertionError: + e = exvalue() + s = str(e) + assert s.startswith("assert 1 is 2") + + +@py.test.mark.skipif("sys.version_info < (2,6)") +def test_attrib(): + class Foo(object): + b = 1 + i = Foo() + try: + assert i.b == 2 + except AssertionError: + e = exvalue() + s = str(e) + assert s.startswith("assert 1 == 2") + +@py.test.mark.skipif("sys.version_info < (2,6)") +def test_attrib_inst(): + class Foo(object): + b = 1 + try: + assert Foo().b == 2 + except AssertionError: + e = exvalue() + s = str(e) + assert s.startswith("assert 1 == 2") + +def test_len(): + l = list(range(42)) + try: + assert len(l) == 100 + except AssertionError: + e = exvalue() + s = str(e) + assert s.startswith("assert 42 == 100") + assert "where 42 = len([" in s + +def test_assert_non_string_message(): + class A: + def __str__(self): + return "hello" + try: + assert 0 == 1, A() + except AssertionError: + e = exvalue() + assert e.msg == "hello" + +def test_assert_keyword_arg(): + def f(x=3): + return False + try: + assert f(x=5) + except AssertionError: + e = exvalue() + assert "x=5" in e.msg + +# These tests should both fail, but should fail nicely... +class WeirdRepr: + def __repr__(self): + return '' + +def bug_test_assert_repr(): + v = WeirdRepr() + try: + assert v == 1 + except AssertionError: + e = exvalue() + assert e.msg.find('WeirdRepr') != -1 + assert e.msg.find('second line') != -1 + assert 0 + +def test_assert_non_string(): + try: + assert 0, ['list'] + except AssertionError: + e = exvalue() + assert e.msg.find("list") != -1 + +def test_assert_implicit_multiline(): + try: + x = [1,2,3] + assert x != [1, + 2, 3] + except AssertionError: + e = exvalue() + assert e.msg.find('assert [1, 2, 3] !=') != -1 + + +def test_assert_with_brokenrepr_arg(): + class BrokenRepr: + def __repr__(self): 0 / 0 + e = AssertionError(BrokenRepr()) + if e.msg.find("broken __repr__") == -1: + py.test.fail("broken __repr__ not handle correctly") + +def test_multiple_statements_per_line(): + try: + a = 1; assert a == 2 + except AssertionError: + e = exvalue() + assert "assert 1 == 2" in e.msg + +def test_power(): + try: + assert 2**3 == 7 + except AssertionError: + e = exvalue() + assert "assert (2 ** 3) == 7" in e.msg + + +class TestView: + + def setup_class(cls): + cls.View = pytest.importorskip("_pytest.assertion.oldinterpret").View + + def test_class_dispatch(self): + ### Use a custom class hierarchy with existing instances + + class Picklable(self.View): + pass + + class Simple(Picklable): + __view__ = object + def pickle(self): + return repr(self.__obj__) + + class Seq(Picklable): + __view__ = list, tuple, dict + def pickle(self): + return ';'.join( + [Picklable(item).pickle() for item in self.__obj__]) + + class Dict(Seq): + __view__ = dict + def pickle(self): + return Seq.pickle(self) + '!' + Seq(self.values()).pickle() + + assert Picklable(123).pickle() == '123' + assert Picklable([1,[2,3],4]).pickle() == '1;2;3;4' + assert Picklable({1:2}).pickle() == '1!2' + + def test_viewtype_class_hierarchy(self): + # Use a custom class hierarchy based on attributes of existing instances + class Operation: + "Existing class that I don't want to change." + def __init__(self, opname, *args): + self.opname = opname + self.args = args + + existing = [Operation('+', 4, 5), + Operation('getitem', '', 'join'), + Operation('setattr', 'x', 'y', 3), + Operation('-', 12, 1)] + + class PyOp(self.View): + def __viewkey__(self): + return self.opname + def generate(self): + return '%s(%s)' % (self.opname, ', '.join(map(repr, self.args))) + + class PyBinaryOp(PyOp): + __view__ = ('+', '-', '*', '/') + def generate(self): + return '%s %s %s' % (self.args[0], self.opname, self.args[1]) + + codelines = [PyOp(op).generate() for op in existing] + assert codelines == ["4 + 5", "getitem('', 'join')", + "setattr('x', 'y', 3)", "12 - 1"] + +@py.test.mark.skipif("sys.version_info < (2,6)") +def test_assert_customizable_reprcompare(monkeypatch): + monkeypatch.setattr(assertion, '_reprcompare', lambda *args: 'hello') + try: + assert 3 == 4 + except AssertionError: + e = exvalue() + s = str(e) + assert "hello" in s + +def test_assert_long_source_1(): + try: + assert len == [ + (None, ['somet text', 'more text']), + ] + except AssertionError: + e = exvalue() + s = str(e) + assert 're-run' not in s + assert 'somet text' in s + +def test_assert_long_source_2(): + try: + assert(len == [ + (None, ['somet text', 'more text']), + ]) + except AssertionError: + e = exvalue() + s = str(e) + assert 're-run' not in s + assert 'somet text' in s + +def test_assert_raise_alias(testdir): + testdir.makepyfile(""" + "PYTEST_DONT_REWRITE" + import sys + EX = AssertionError + def test_hello(): + raise EX("hello" + "multi" + "line") + """) + result = testdir.runpytest() + result.stdout.fnmatch_lines([ + "*def test_hello*", + "*raise EX*", + "*1 failed*", + ]) + + +@pytest.mark.skipif("sys.version_info < (2,5)") +def test_assert_raise_subclass(): + class SomeEx(AssertionError): + def __init__(self, *args): + super(SomeEx, self).__init__() + try: + raise SomeEx("hello") + except AssertionError: + s = str(exvalue()) + assert 're-run' not in s + assert 'could not determine' in s + +def test_assert_raises_in_nonzero_of_object_pytest_issue10(): + class A(object): + def __nonzero__(self): + raise ValueError(42) + def __lt__(self, other): + return A() + def __repr__(self): + return "" + def myany(x): + return True + try: + assert not(myany(A() < 0)) + except AssertionError: + e = exvalue() + s = str(e) + assert " < 0" in s diff --git a/testing/test_assertion.py b/testing/test_assertion.py index 567cebbf1..5470f6416 100644 --- a/testing/test_assertion.py +++ b/testing/test_assertion.py @@ -2,11 +2,12 @@ import sys import py, pytest import _pytest.assertion as plugin +from _pytest.assertion import reinterpret needsnewassert = pytest.mark.skipif("sys.version_info < (2,6)") def interpret(expr): - return py.code._reinterpret(expr, py.code.Frame(sys._getframe(1))) + return reinterpret.reinterpret(expr, py.code.Frame(sys._getframe(1))) class TestBinReprIntegration: pytestmark = needsnewassert @@ -25,7 +26,7 @@ class TestBinReprIntegration: self.right = right mockhook = MockHook() monkeypatch = request.getfuncargvalue("monkeypatch") - monkeypatch.setattr(py.code, '_reprcompare', mockhook) + monkeypatch.setattr(plugin, '_reprcompare', mockhook) return mockhook def test_pytest_assertrepr_compare_called(self, hook): @@ -40,13 +41,13 @@ class TestBinReprIntegration: assert hook.right == [0, 2] def test_configure_unconfigure(self, testdir, hook): - assert hook == py.code._reprcompare + assert hook == plugin._reprcompare config = testdir.parseconfig() plugin.pytest_configure(config) - assert hook != py.code._reprcompare + assert hook != plugin._reprcompare from _pytest.config import pytest_unconfigure pytest_unconfigure(config) - assert hook == py.code._reprcompare + assert hook == plugin._reprcompare def callequal(left, right): return plugin.pytest_assertrepr_compare('==', left, right) diff --git a/testing/test_assertrewrite.py b/testing/test_assertrewrite.py index 580eed420..d713b6e25 100644 --- a/testing/test_assertrewrite.py +++ b/testing/test_assertrewrite.py @@ -4,15 +4,16 @@ import pytest ast = pytest.importorskip("ast") +from _pytest import assertion from _pytest.assertion.rewrite import rewrite_asserts def setup_module(mod): - mod._old_reprcompare = py.code._reprcompare + mod._old_reprcompare = assertion._reprcompare py.code._reprcompare = None def teardown_module(mod): - py.code._reprcompare = mod._old_reprcompare + assertion._reprcompare = mod._old_reprcompare del mod._old_reprcompare @@ -229,13 +230,13 @@ class TestAssertionRewrite: def test_custom_reprcompare(self, monkeypatch): def my_reprcompare(op, left, right): return "42" - monkeypatch.setattr(py.code, "_reprcompare", my_reprcompare) + monkeypatch.setattr(assertion, "_reprcompare", my_reprcompare) def f(): assert 42 < 3 assert getmsg(f) == "assert 42" def my_reprcompare(op, left, right): return "%s %s %s" % (left, op, right) - monkeypatch.setattr(py.code, "_reprcompare", my_reprcompare) + monkeypatch.setattr(assertion, "_reprcompare", my_reprcompare) def f(): assert 1 < 3 < 5 <= 4 < 7 assert getmsg(f) == "assert 5 <= 4"