feat: support keyword arguments in marker expressions

Fixes #12281
This commit is contained in:
lovetheguitar 2024-06-20 23:09:14 +02:00
parent e8fa8dd31c
commit 15c33fbaa3
4 changed files with 304 additions and 17 deletions

View File

@ -2,6 +2,7 @@
from __future__ import annotations
import collections
import dataclasses
from typing import AbstractSet
from typing import Collection
@ -181,7 +182,9 @@ class KeywordMatcher:
return cls(mapped_names)
def __call__(self, subname: str) -> bool:
def __call__(self, subname: str, /, **kwargs: object) -> bool:
if kwargs:
raise UsageError("Keyword expressions do not support call parameters.")
subname = subname.lower()
names = (name.lower() for name in self._names)
@ -211,6 +214,9 @@ def deselect_by_keyword(items: list[Item], config: Config) -> None:
items[:] = remaining
NOT_NONE_SENTINEL = object()
@dataclasses.dataclass
class MarkMatcher:
"""A matcher for markers which are present.
@ -218,17 +224,31 @@ class MarkMatcher:
Tries to match on any marker names, attached to the given colitem.
"""
__slots__ = ("own_mark_names",)
__slots__ = ("own_mark_name_mapping",)
own_mark_names: AbstractSet[str]
own_mark_name_mapping: dict[str, list[Mark]]
@classmethod
def from_item(cls, item: Item) -> MarkMatcher:
mark_names = {mark.name for mark in item.iter_markers()}
return cls(mark_names)
mark_name_mapping = collections.defaultdict(list)
for mark in item.iter_markers():
mark_name_mapping[mark.name].append(mark)
return cls(mark_name_mapping)
def __call__(self, name: str) -> bool:
return name in self.own_mark_names
def __call__(self, name: str, /, **kwargs: object) -> bool:
if not (matches := self.own_mark_name_mapping.get(name, [])):
return False
if not kwargs:
return True
for mark in matches:
if all(
mark.kwargs.get(k, NOT_NONE_SENTINEL) == v for k, v in kwargs.items()
):
return True
return False
def deselect_by_mark(items: list[Item], config: Config) -> None:

View File

@ -5,7 +5,8 @@ The grammar is:
expression: expr? EOF
expr: and_expr ('or' and_expr)*
and_expr: not_expr ('and' not_expr)*
not_expr: 'not' not_expr | '(' expr ')' | ident
not_expr: 'not' not_expr | '(' expr ')' | ident ( '(' name '=' value ( ', ' name '=' value )* ')')*
ident: (\w|:|\+|-|\.|\[|\]|\\|/)+
The semantics are:
@ -20,12 +21,13 @@ from __future__ import annotations
import ast
import dataclasses
import enum
import keyword
import re
import types
from typing import Callable
from typing import Iterator
from typing import Mapping
from typing import NoReturn
from typing import Protocol
from typing import Sequence
@ -43,6 +45,9 @@ class TokenType(enum.Enum):
NOT = "not"
IDENT = "identifier"
EOF = "end of input"
EQUAL = "="
STRING = "str"
COMMA = ","
@dataclasses.dataclass(frozen=True)
@ -86,6 +91,27 @@ class Scanner:
elif input[pos] == ")":
yield Token(TokenType.RPAREN, ")", pos)
pos += 1
elif input[pos] == "=":
yield Token(TokenType.EQUAL, "=", pos)
pos += 1
elif input[pos] == ",":
yield Token(TokenType.COMMA, ",", pos)
pos += 1
elif (quote_char := input[pos]) == "'" or input[pos] == '"':
quote_position = input[pos + 1 :].find(quote_char)
if quote_position == -1:
raise ParseError(
pos + 1,
f'closing quote "{quote_char}" is missing',
)
value = input[pos : pos + 2 + quote_position]
if "\\" in value:
raise ParseError(
pos + 1,
"escaping not supported in marker expression",
)
yield Token(TokenType.STRING, value, pos)
pos += len(value)
else:
match = re.match(r"(:?\w|:|\+|-|\.|\[|\]|\\|/)+", input[pos:])
if match:
@ -166,18 +192,84 @@ def not_expr(s: Scanner) -> ast.expr:
return ret
ident = s.accept(TokenType.IDENT)
if ident:
return ast.Name(IDENT_PREFIX + ident.value, ast.Load())
name = ast.Name(IDENT_PREFIX + ident.value, ast.Load())
if s.accept(TokenType.LPAREN):
ret = ast.Call(func=name, args=[], keywords=all_kwargs(s))
s.accept(TokenType.RPAREN, reject=True)
else:
ret = name
return ret
s.reject((TokenType.NOT, TokenType.LPAREN, TokenType.IDENT))
class MatcherAdapter(Mapping[str, bool]):
BUILTIN_MATCHERS = {"True": True, "False": False, "None": None}
def single_kwarg(s: Scanner) -> ast.keyword:
keyword_name = s.accept(TokenType.IDENT, reject=True)
assert keyword_name is not None # for mypy
if not keyword_name.value.isidentifier() or keyword.iskeyword(keyword_name.value):
raise ParseError(
keyword_name.pos + 1,
f'unexpected character/s "{keyword_name.value}"',
)
s.accept(TokenType.EQUAL, reject=True)
if value_token := s.accept(TokenType.STRING):
value: str | int | bool | None = value_token.value[1:-1] # strip quotes
else:
value_token = s.accept(TokenType.IDENT, reject=True)
assert value_token is not None # for mypy
if (
(number := value_token.value).isdigit()
or number.startswith("-")
and number[1:].isdigit()
):
value = int(number)
elif value_token.value in BUILTIN_MATCHERS:
value = BUILTIN_MATCHERS[value_token.value]
else:
raise ParseError(
value_token.pos + 1,
f'unexpected character/s "{value_token.value}"',
)
ret = ast.keyword(keyword_name.value, ast.Constant(value))
return ret
def all_kwargs(s: Scanner) -> list[ast.keyword]:
ret = [single_kwarg(s)]
while s.accept(TokenType.COMMA):
ret.append(single_kwarg(s))
return ret
class MatcherCall(Protocol):
def __call__(self, name: str, /, **kwargs: object) -> bool: ...
@dataclasses.dataclass
class MatcherNameAdapter:
matcher: MatcherCall
name: str
def __bool__(self) -> bool:
return self.matcher(self.name)
def __call__(self, **kwargs: object) -> bool:
return self.matcher(self.name, **kwargs)
class MatcherAdapter(Mapping[str, MatcherNameAdapter]):
"""Adapts a matcher function to a locals mapping as required by eval()."""
def __init__(self, matcher: Callable[[str], bool]) -> None:
def __init__(self, matcher: MatcherCall) -> None:
self.matcher = matcher
def __getitem__(self, key: str) -> bool:
return self.matcher(key[len(IDENT_PREFIX) :])
def __getitem__(self, key: str) -> MatcherNameAdapter:
return MatcherNameAdapter(matcher=self.matcher, name=key[len(IDENT_PREFIX) :])
def __iter__(self) -> Iterator[str]:
raise NotImplementedError()
@ -211,7 +303,7 @@ class Expression:
)
return Expression(code)
def evaluate(self, matcher: Callable[[str], bool]) -> bool:
def evaluate(self, matcher: MatcherCall) -> bool:
"""Evaluate the match expression.
:param matcher:
@ -220,5 +312,5 @@ class Expression:
:returns: Whether the expression matches or not.
"""
ret: bool = eval(self.code, {"__builtins__": {}}, MatcherAdapter(matcher))
ret: bool = bool(eval(self.code, {"__builtins__": {}}, MatcherAdapter(matcher)))
return ret

View File

@ -233,6 +233,54 @@ def test_mark_option(
assert passed_str == expected_passed
@pytest.mark.parametrize(
("expr", "expected_passed"),
[ # TODO: improve/sort out
("car(color='red')", ["test_one"]),
("car(color='red') or car(color='blue')", ["test_one", "test_two"]),
("car and not car(temp=5)", ["test_one", "test_three"]),
("car(temp=4)", ["test_one"]),
("car(temp=4) or car(temp=5)", ["test_one", "test_two"]),
("car(temp=4) and car(temp=5)", []),
("car(temp=-5)", ["test_three"]),
("car(ac=True)", ["test_one"]),
("car(ac=False)", ["test_two"]),
("car(ac=None)", ["test_three"]), # test NOT_NONE_SENTINEL
],
ids=str,
)
def test_mark_option_with_kwargs(
expr: str, expected_passed: list[str | None], pytester: Pytester
) -> None:
pytester.makepyfile(
"""
import pytest
@pytest.mark.car
@pytest.mark.car(ac=True)
@pytest.mark.car(temp=4)
@pytest.mark.car(color="red")
def test_one():
pass
@pytest.mark.car
@pytest.mark.car(ac=False)
@pytest.mark.car(temp=5)
@pytest.mark.car(color="blue")
def test_two():
pass
@pytest.mark.car
@pytest.mark.car(ac=None)
@pytest.mark.car(temp=-5)
def test_three():
pass
"""
)
rec = pytester.inline_run("-m", expr)
passed, skipped, fail = rec.listoutcomes()
passed_str = [x.nodeid.split("::")[-1] for x in passed]
assert passed_str == expected_passed
@pytest.mark.parametrize(
("expr", "expected_passed"),
[("interface", ["test_interface"]), ("not interface", ["test_nointer"])],

View File

@ -1,14 +1,19 @@
from __future__ import annotations
import collections
from typing import Callable
from typing import cast
from _pytest.mark import MarkMatcher
from _pytest.mark import structures
from _pytest.mark.expression import Expression
from _pytest.mark.expression import MatcherCall
from _pytest.mark.expression import ParseError
import pytest
def evaluate(input: str, matcher: Callable[[str], bool]) -> bool:
return Expression.compile(input).evaluate(matcher)
return Expression.compile(input).evaluate(cast(MatcherCall, matcher))
def test_empty_is_false() -> None:
@ -153,6 +158,8 @@ def test_syntax_errors(expr: str, column: int, message: str) -> None:
"1234",
"1234abcd",
"1234and",
"1234or",
"1234not",
"notandor",
"not_and_or",
"not[and]or",
@ -195,3 +202,123 @@ def test_valid_idents(ident: str) -> None:
def test_invalid_idents(ident: str) -> None:
with pytest.raises(ParseError):
evaluate(ident, lambda ident: True)
@pytest.mark.parametrize(
"expr, expected_error_msg",
(
("mark(1=2)", 'unexpected character/s "1"'),
("mark(/=2)", 'unexpected character/s "/"'),
("mark(True=False)", 'unexpected character/s "True"'),
("mark(def=False)", 'unexpected character/s "def"'),
("mark(class=False)", 'unexpected character/s "class"'),
("mark(if=False)", 'unexpected character/s "if"'),
("mark(else=False)", 'unexpected character/s "else"'),
("mark(1)", 'unexpected character/s "1"'),
("mark(var:=False", 'unexpected character/s "var:"'),
("mark(valid=False, def=1)", 'unexpected character/s "def"'),
("mark(var==", "expected identifier; got ="),
("mark(var=none)", 'unexpected character/s "none"'),
("mark(var=1.1)", 'unexpected character/s "1.1"'),
("mark(var)", "expected =; got right parenthesis"),
("mark(var=')", """closing quote "'" is missing"""),
('mark(var=")', 'closing quote """ is missing'),
("""mark(var="')""", 'closing quote """ is missing'),
("""mark(var='")""", """closing quote "'" is missing"""),
(r"mark(var='\hugo')", "escaping not supported in marker expression"),
),
)
def test_invalid_kwarg_name_or_value( # TODO: move to `test_syntax_errors` ?
expr: str, expected_error_msg: str, mark_matcher: MarkMatcher
) -> None:
with pytest.raises(ParseError, match=expected_error_msg):
assert evaluate(expr, mark_matcher)
@pytest.fixture(scope="session")
def mark_matcher() -> MarkMatcher:
markers = []
mark_name_mapping = collections.defaultdict(list)
def create_marker(name: str, kwargs: dict[str, object]) -> structures.Mark:
return structures.Mark(name=name, args=tuple(), kwargs=kwargs, _ispytest=True)
markers.append(create_marker("number_mark", {"a": 1, "b": 2, "c": 3, "d": 999_999}))
markers.append(
create_marker("builtin_matchers_mark", {"x": True, "y": False, "z": None})
)
markers.append(
create_marker(
"str_mark",
{"m": "M", "space": "with space", "aaאבגדcc": "aaאבגדcc", "אבגד": "אבגד"},
)
)
for marker in markers:
mark_name_mapping[marker.name].append(marker)
return MarkMatcher(mark_name_mapping)
@pytest.mark.parametrize(
"expr, expected",
(
# happy cases
("number_mark(a=1)", True),
("number_mark(b=2)", True),
("number_mark(a=1,b=2)", True),
("number_mark(a=1, b=2)", True),
("number_mark(d=999999)", True),
("number_mark(a = 1,b= 2, c = 3)", True),
# sad cases
("number_mark(a=6)", False),
("number_mark(b=6)", False),
("number_mark(a=1,b=6)", False),
("number_mark(a=6,b=2)", False),
("number_mark(a = 1,b= 2, c = 6)", False),
("number_mark(a='1')", False),
),
)
def test_keyword_expressions_with_numbers(
expr: str, expected: bool, mark_matcher: MarkMatcher
) -> None:
assert evaluate(expr, mark_matcher) is expected
@pytest.mark.parametrize(
"expr, expected",
(
("builtin_matchers_mark(x=True)", True),
("builtin_matchers_mark(x=False)", False),
("builtin_matchers_mark(y=True)", False),
("builtin_matchers_mark(y=False)", True),
("builtin_matchers_mark(z=None)", True),
("builtin_matchers_mark(z=False)", False),
("builtin_matchers_mark(z=True)", False),
("builtin_matchers_mark(z=0)", False),
("builtin_matchers_mark(z=1)", False),
),
)
def test_builtin_matchers_keyword_expressions( # TODO: naming when decided
expr: str, expected: bool, mark_matcher: MarkMatcher
) -> None:
assert evaluate(expr, mark_matcher) is expected
@pytest.mark.parametrize(
"expr, expected",
(
("str_mark(m='M')", True),
('str_mark(m="M")', True),
("str_mark(aaאבגדcc='aaאבגדcc')", True),
("str_mark(אבגד='אבגד')", True),
("str_mark(space='with space')", True),
("str_mark(m='wrong')", False),
("str_mark(aaאבגדcc='wrong')", False),
("str_mark(אבגד='wrong')", False),
),
)
def test_str_keyword_expressions(
expr: str, expected: bool, mark_matcher: MarkMatcher
) -> None:
assert evaluate(expr, mark_matcher) is expected