refactor using self.variable

This commit is contained in:
Alessio Izzo 2023-03-02 22:06:14 +01:00
parent eab808739c
commit ea73f4a1d2
No known key found for this signature in database
GPG Key ID: 2B5983EE2D924936
2 changed files with 14 additions and 18 deletions

View File

@ -639,8 +639,12 @@ class AssertionRewriter(ast.NodeVisitor):
.push_format_context() and .pop_format_context() which allows
to build another %-formatted string while already building one.
This state is reset on every new assert statement visited and used
by the other visitors.
:variables_overwrite: A dict filled with references to variables
that change value within an assert. This happens when a variable is
reassigned with the walrus operator
This state, except the variables_overwrite, is reset on every new assert
statement visited and used by the other visitors.
"""
def __init__(
@ -974,7 +978,6 @@ class AssertionRewriter(ast.NodeVisitor):
levels = len(boolop.values) - 1
self.push_format_context()
# Process each operand, short-circuiting if needed.
pytest_temp = None
for i, v in enumerate(boolop.values):
if i:
fail_inner: List[ast.stmt] = []
@ -985,17 +988,14 @@ class AssertionRewriter(ast.NodeVisitor):
if (
isinstance(v, ast.Compare)
and isinstance(v.left, namedExpr)
and (
v.left.target.id
in [
ast_expr.id
for ast_expr in boolop.values[:i]
if hasattr(ast_expr, "id")
]
or v.left.target.id == pytest_temp
)
and v.left.target.id
in [
ast_expr.id
for ast_expr in boolop.values[:i]
if hasattr(ast_expr, "id")
]
):
pytest_temp = util.compose_temp_variable(v.left.target.id)
pytest_temp = self.variable()
self.variables_overwrite[v.left.target.id] = pytest_temp
v.left.target.id = pytest_temp
self.push_format_context()
@ -1092,7 +1092,7 @@ class AssertionRewriter(ast.NodeVisitor):
and isinstance(left_res, ast.Name)
and next_operand.target.id == left_res.id
):
next_operand.target.id = util.compose_temp_variable(left_res.id)
next_operand.target.id = self.variable()
self.variables_overwrite[left_res.id] = next_operand.target.id
next_res, next_expl = self.visit(next_operand)
if isinstance(next_operand, (ast.Compare, ast.BoolOp)):

View File

@ -520,7 +520,3 @@ def running_on_ci() -> bool:
"""Check if we're currently running on a CI system."""
env_vars = ["CI", "BUILD_NUMBER"]
return any(var in os.environ for var in env_vars)
def compose_temp_variable(original_variable: str) -> str:
return f"pytest_{original_variable}_temp"