refactor trying to clean the code and add comments where conditions on instances of walrus operator
This commit is contained in:
parent
52f818d2be
commit
0bc4bdc063
|
@ -52,7 +52,6 @@ else:
|
||||||
|
|
||||||
assertstate_key = StashKey["AssertionState"]()
|
assertstate_key = StashKey["AssertionState"]()
|
||||||
|
|
||||||
|
|
||||||
# pytest caches rewritten pycs in pycache dirs
|
# pytest caches rewritten pycs in pycache dirs
|
||||||
PYTEST_TAG = f"{sys.implementation.cache_tag}-pytest-{version}"
|
PYTEST_TAG = f"{sys.implementation.cache_tag}-pytest-{version}"
|
||||||
PYC_EXT = ".py" + (__debug__ and "c" or "o")
|
PYC_EXT = ".py" + (__debug__ and "c" or "o")
|
||||||
|
@ -945,7 +944,8 @@ class AssertionRewriter(ast.NodeVisitor):
|
||||||
|
|
||||||
def visit_NamedExpr(self, name: namedExpr) -> Tuple[namedExpr, str]:
|
def visit_NamedExpr(self, name: namedExpr) -> Tuple[namedExpr, str]:
|
||||||
# This method handles the 'walrus operator' repr of the target
|
# This method handles the 'walrus operator' repr of the target
|
||||||
# name if it's a local variable or _should_repr_global_name() thinks it's acceptable.
|
# name if it's a local variable or _should_repr_global_name()
|
||||||
|
# thinks it's acceptable.
|
||||||
locs = ast.Call(self.builtin("locals"), [], [])
|
locs = ast.Call(self.builtin("locals"), [], [])
|
||||||
target_id = name.target.id # type: ignore[attr-defined]
|
target_id = name.target.id # type: ignore[attr-defined]
|
||||||
inlocs = ast.Compare(ast.Str(target_id), [ast.In()], [locs])
|
inlocs = ast.Compare(ast.Str(target_id), [ast.In()], [locs])
|
||||||
|
@ -981,8 +981,11 @@ class AssertionRewriter(ast.NodeVisitor):
|
||||||
# cond is set in a prior loop iteration below
|
# cond is set in a prior loop iteration below
|
||||||
self.expl_stmts.append(ast.If(cond, fail_inner, [])) # noqa
|
self.expl_stmts.append(ast.If(cond, fail_inner, [])) # noqa
|
||||||
self.expl_stmts = fail_inner
|
self.expl_stmts = fail_inner
|
||||||
if isinstance(v, ast.Compare):
|
# Check if the left operand is a namedExpr and the value has already been visited
|
||||||
if isinstance(v.left, namedExpr) and (
|
if (
|
||||||
|
isinstance(v, ast.Compare)
|
||||||
|
and isinstance(v.left, namedExpr)
|
||||||
|
and (
|
||||||
v.left.target.id
|
v.left.target.id
|
||||||
in [
|
in [
|
||||||
ast_expr.id
|
ast_expr.id
|
||||||
|
@ -990,16 +993,11 @@ class AssertionRewriter(ast.NodeVisitor):
|
||||||
if hasattr(ast_expr, "id")
|
if hasattr(ast_expr, "id")
|
||||||
]
|
]
|
||||||
or v.left.target.id == pytest_temp
|
or v.left.target.id == pytest_temp
|
||||||
):
|
)
|
||||||
pytest_temp = f"pytest_{v.left.target.id}_temp"
|
):
|
||||||
self.variables_overwrite[v.left.target.id] = pytest_temp
|
pytest_temp = util.compose_temp_variable(v.left.target.id)
|
||||||
v.left.target.id = pytest_temp
|
self.variables_overwrite[v.left.target.id] = pytest_temp
|
||||||
|
v.left.target.id = pytest_temp
|
||||||
elif isinstance(v.left, ast.Name) and (
|
|
||||||
pytest_temp is not None
|
|
||||||
and v.left.id == pytest_temp.lstrip("pytest_").rstrip("_temp")
|
|
||||||
):
|
|
||||||
v.left.id = pytest_temp
|
|
||||||
self.push_format_context()
|
self.push_format_context()
|
||||||
res, expl = self.visit(v)
|
res, expl = self.visit(v)
|
||||||
body.append(ast.Assign([ast.Name(res_var, ast.Store())], res))
|
body.append(ast.Assign([ast.Name(res_var, ast.Store())], res))
|
||||||
|
@ -1075,6 +1073,7 @@ class AssertionRewriter(ast.NodeVisitor):
|
||||||
|
|
||||||
def visit_Compare(self, comp: ast.Compare) -> Tuple[ast.expr, str]:
|
def visit_Compare(self, comp: ast.Compare) -> Tuple[ast.expr, str]:
|
||||||
self.push_format_context()
|
self.push_format_context()
|
||||||
|
# We first check if we have overwritten a variable in the previous assert
|
||||||
if isinstance(comp.left, ast.Name) and comp.left.id in self.variables_overwrite:
|
if isinstance(comp.left, ast.Name) and comp.left.id in self.variables_overwrite:
|
||||||
comp.left.id = self.variables_overwrite[comp.left.id]
|
comp.left.id = self.variables_overwrite[comp.left.id]
|
||||||
left_res, left_expl = self.visit(comp.left)
|
left_res, left_expl = self.visit(comp.left)
|
||||||
|
@ -1093,7 +1092,7 @@ class AssertionRewriter(ast.NodeVisitor):
|
||||||
and isinstance(left_res, ast.Name)
|
and isinstance(left_res, ast.Name)
|
||||||
and next_operand.target.id == left_res.id
|
and next_operand.target.id == left_res.id
|
||||||
):
|
):
|
||||||
next_operand.target.id = f"pytest_{left_res.id}_temp"
|
next_operand.target.id = util.compose_temp_variable(left_res.id)
|
||||||
self.variables_overwrite[left_res.id] = next_operand.target.id
|
self.variables_overwrite[left_res.id] = next_operand.target.id
|
||||||
next_res, next_expl = self.visit(next_operand)
|
next_res, next_expl = self.visit(next_operand)
|
||||||
if isinstance(next_operand, (ast.Compare, ast.BoolOp)):
|
if isinstance(next_operand, (ast.Compare, ast.BoolOp)):
|
||||||
|
|
|
@ -520,3 +520,7 @@ def running_on_ci() -> bool:
|
||||||
"""Check if we're currently running on a CI system."""
|
"""Check if we're currently running on a CI system."""
|
||||||
env_vars = ["CI", "BUILD_NUMBER"]
|
env_vars = ["CI", "BUILD_NUMBER"]
|
||||||
return any(var in os.environ for var in env_vars)
|
return any(var in os.environ for var in env_vars)
|
||||||
|
|
||||||
|
|
||||||
|
def compose_temp_variable(original_variable: str) -> str:
|
||||||
|
return f"pytest_{original_variable}_temp"
|
||||||
|
|
Loading…
Reference in New Issue