Merge pull request #5373 from asottile/revert_all_handling
Revert unrolling of `all()`
This commit is contained in:
		
						commit
						63099bc282
					
				| 
						 | 
					@ -1 +0,0 @@
 | 
				
			||||||
Fix assertion rewriting of ``all()`` calls to deal with non-generators.
 | 
					 | 
				
			||||||
| 
						 | 
					@ -0,0 +1 @@
 | 
				
			||||||
 | 
					Revert unrolling of ``all()`` to fix ``NameError`` on nested comprehensions.
 | 
				
			||||||
| 
						 | 
					@ -0,0 +1 @@
 | 
				
			||||||
 | 
					Revert unrolling of ``all()`` to fix incorrect handling of generators with ``if``.
 | 
				
			||||||
| 
						 | 
					@ -0,0 +1 @@
 | 
				
			||||||
 | 
					Revert unrolling of ``all()`` to fix incorrect assertion when using ``all()`` in an expression.
 | 
				
			||||||
| 
						 | 
					@ -903,22 +903,10 @@ warn_explicit(
 | 
				
			||||||
        res = self.assign(ast.BinOp(left_expr, binop.op, right_expr))
 | 
					        res = self.assign(ast.BinOp(left_expr, binop.op, right_expr))
 | 
				
			||||||
        return res, explanation
 | 
					        return res, explanation
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @staticmethod
 | 
					 | 
				
			||||||
    def _is_any_call_with_generator_or_list_comprehension(call):
 | 
					 | 
				
			||||||
        """Return True if the Call node is an 'any' call with a generator or list comprehension"""
 | 
					 | 
				
			||||||
        return (
 | 
					 | 
				
			||||||
            isinstance(call.func, ast.Name)
 | 
					 | 
				
			||||||
            and call.func.id == "all"
 | 
					 | 
				
			||||||
            and len(call.args) == 1
 | 
					 | 
				
			||||||
            and isinstance(call.args[0], (ast.GeneratorExp, ast.ListComp))
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def visit_Call(self, call):
 | 
					    def visit_Call(self, call):
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
        visit `ast.Call` nodes
 | 
					        visit `ast.Call` nodes
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
        if self._is_any_call_with_generator_or_list_comprehension(call):
 | 
					 | 
				
			||||||
            return self._visit_all(call)
 | 
					 | 
				
			||||||
        new_func, func_expl = self.visit(call.func)
 | 
					        new_func, func_expl = self.visit(call.func)
 | 
				
			||||||
        arg_expls = []
 | 
					        arg_expls = []
 | 
				
			||||||
        new_args = []
 | 
					        new_args = []
 | 
				
			||||||
| 
						 | 
					@ -942,25 +930,6 @@ warn_explicit(
 | 
				
			||||||
        outer_expl = "{}\n{{{} = {}\n}}".format(res_expl, res_expl, expl)
 | 
					        outer_expl = "{}\n{{{} = {}\n}}".format(res_expl, res_expl, expl)
 | 
				
			||||||
        return res, outer_expl
 | 
					        return res, outer_expl
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def _visit_all(self, call):
 | 
					 | 
				
			||||||
        """Special rewrite for the builtin all function, see #5062"""
 | 
					 | 
				
			||||||
        gen_exp = call.args[0]
 | 
					 | 
				
			||||||
        assertion_module = ast.Module(
 | 
					 | 
				
			||||||
            body=[ast.Assert(test=gen_exp.elt, lineno=1, msg="", col_offset=1)]
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
        AssertionRewriter(module_path=None, config=None).run(assertion_module)
 | 
					 | 
				
			||||||
        for_loop = ast.For(
 | 
					 | 
				
			||||||
            iter=gen_exp.generators[0].iter,
 | 
					 | 
				
			||||||
            target=gen_exp.generators[0].target,
 | 
					 | 
				
			||||||
            body=assertion_module.body,
 | 
					 | 
				
			||||||
            orelse=[],
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
        self.statements.append(for_loop)
 | 
					 | 
				
			||||||
        return (
 | 
					 | 
				
			||||||
            ast.Num(n=1),
 | 
					 | 
				
			||||||
            "",
 | 
					 | 
				
			||||||
        )  # Return an empty expression, all the asserts are in the for_loop
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def visit_Starred(self, starred):
 | 
					    def visit_Starred(self, starred):
 | 
				
			||||||
        # From Python 3.5, a Starred node can appear in a function call
 | 
					        # From Python 3.5, a Starred node can appear in a function call
 | 
				
			||||||
        res, expl = self.visit(starred.value)
 | 
					        res, expl = self.visit(starred.value)
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -635,12 +635,6 @@ class TestAssertionRewrite:
 | 
				
			||||||
        else:
 | 
					        else:
 | 
				
			||||||
            assert lines == ["assert 0 == 1\n +  where 1 = \\n{ \\n~ \\n}.a"]
 | 
					            assert lines == ["assert 0 == 1\n +  where 1 = \\n{ \\n~ \\n}.a"]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def test_unroll_expression(self):
 | 
					 | 
				
			||||||
        def f():
 | 
					 | 
				
			||||||
            assert all(x == 1 for x in range(10))
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        assert "0 == 1" in getmsg(f)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def test_custom_repr_non_ascii(self):
 | 
					    def test_custom_repr_non_ascii(self):
 | 
				
			||||||
        def f():
 | 
					        def f():
 | 
				
			||||||
            class A:
 | 
					            class A:
 | 
				
			||||||
| 
						 | 
					@ -656,78 +650,6 @@ class TestAssertionRewrite:
 | 
				
			||||||
        assert "UnicodeDecodeError" not in msg
 | 
					        assert "UnicodeDecodeError" not in msg
 | 
				
			||||||
        assert "UnicodeEncodeError" not in msg
 | 
					        assert "UnicodeEncodeError" not in msg
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def test_unroll_all_generator(self, testdir):
 | 
					 | 
				
			||||||
        testdir.makepyfile(
 | 
					 | 
				
			||||||
            """
 | 
					 | 
				
			||||||
            def check_even(num):
 | 
					 | 
				
			||||||
                if num % 2 == 0:
 | 
					 | 
				
			||||||
                    return True
 | 
					 | 
				
			||||||
                return False
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            def test_generator():
 | 
					 | 
				
			||||||
                odd_list = list(range(1,9,2))
 | 
					 | 
				
			||||||
                assert all(check_even(num) for num in odd_list)"""
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
        result = testdir.runpytest()
 | 
					 | 
				
			||||||
        result.stdout.fnmatch_lines(["*assert False*", "*where False = check_even(1)*"])
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def test_unroll_all_list_comprehension(self, testdir):
 | 
					 | 
				
			||||||
        testdir.makepyfile(
 | 
					 | 
				
			||||||
            """
 | 
					 | 
				
			||||||
            def check_even(num):
 | 
					 | 
				
			||||||
                if num % 2 == 0:
 | 
					 | 
				
			||||||
                    return True
 | 
					 | 
				
			||||||
                return False
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            def test_list_comprehension():
 | 
					 | 
				
			||||||
                odd_list = list(range(1,9,2))
 | 
					 | 
				
			||||||
                assert all([check_even(num) for num in odd_list])"""
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
        result = testdir.runpytest()
 | 
					 | 
				
			||||||
        result.stdout.fnmatch_lines(["*assert False*", "*where False = check_even(1)*"])
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def test_unroll_all_object(self, testdir):
 | 
					 | 
				
			||||||
        """all() for non generators/non list-comprehensions (#5358)"""
 | 
					 | 
				
			||||||
        testdir.makepyfile(
 | 
					 | 
				
			||||||
            """
 | 
					 | 
				
			||||||
            def test():
 | 
					 | 
				
			||||||
                assert all((1, 0))
 | 
					 | 
				
			||||||
            """
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
        result = testdir.runpytest()
 | 
					 | 
				
			||||||
        result.stdout.fnmatch_lines(["*assert False*", "*where False = all((1, 0))*"])
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def test_unroll_all_starred(self, testdir):
 | 
					 | 
				
			||||||
        """all() for non generators/non list-comprehensions (#5358)"""
 | 
					 | 
				
			||||||
        testdir.makepyfile(
 | 
					 | 
				
			||||||
            """
 | 
					 | 
				
			||||||
            def test():
 | 
					 | 
				
			||||||
                x = ((1, 0),)
 | 
					 | 
				
			||||||
                assert all(*x)
 | 
					 | 
				
			||||||
            """
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
        result = testdir.runpytest()
 | 
					 | 
				
			||||||
        result.stdout.fnmatch_lines(
 | 
					 | 
				
			||||||
            ["*assert False*", "*where False = all(*((1, 0),))*"]
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def test_for_loop(self, testdir):
 | 
					 | 
				
			||||||
        testdir.makepyfile(
 | 
					 | 
				
			||||||
            """
 | 
					 | 
				
			||||||
            def check_even(num):
 | 
					 | 
				
			||||||
                if num % 2 == 0:
 | 
					 | 
				
			||||||
                    return True
 | 
					 | 
				
			||||||
                return False
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            def test_for_loop():
 | 
					 | 
				
			||||||
                odd_list = list(range(1,9,2))
 | 
					 | 
				
			||||||
                for num in odd_list:
 | 
					 | 
				
			||||||
                    assert check_even(num)
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
        result = testdir.runpytest()
 | 
					 | 
				
			||||||
        result.stdout.fnmatch_lines(["*assert False*", "*where False = check_even(1)*"])
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
class TestRewriteOnImport:
 | 
					class TestRewriteOnImport:
 | 
				
			||||||
    def test_pycache_is_a_file(self, testdir):
 | 
					    def test_pycache_is_a_file(self, testdir):
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue