feat: support keyword arguments in marker expressions

Fixes #12281
This commit is contained in:
lovetheguitar 2024-06-20 23:09:14 +02:00
parent e8fa8dd31c
commit 15c33fbaa3
4 changed files with 304 additions and 17 deletions

View File

@ -2,6 +2,7 @@
from __future__ import annotations from __future__ import annotations
import collections
import dataclasses import dataclasses
from typing import AbstractSet from typing import AbstractSet
from typing import Collection from typing import Collection
@ -181,7 +182,9 @@ class KeywordMatcher:
return cls(mapped_names) return cls(mapped_names)
def __call__(self, subname: str) -> bool: def __call__(self, subname: str, /, **kwargs: object) -> bool:
if kwargs:
raise UsageError("Keyword expressions do not support call parameters.")
subname = subname.lower() subname = subname.lower()
names = (name.lower() for name in self._names) names = (name.lower() for name in self._names)
@ -211,6 +214,9 @@ def deselect_by_keyword(items: list[Item], config: Config) -> None:
items[:] = remaining items[:] = remaining
NOT_NONE_SENTINEL = object()
@dataclasses.dataclass @dataclasses.dataclass
class MarkMatcher: class MarkMatcher:
"""A matcher for markers which are present. """A matcher for markers which are present.
@ -218,17 +224,31 @@ class MarkMatcher:
Tries to match on any marker names, attached to the given colitem. Tries to match on any marker names, attached to the given colitem.
""" """
__slots__ = ("own_mark_names",) __slots__ = ("own_mark_name_mapping",)
own_mark_names: AbstractSet[str] own_mark_name_mapping: dict[str, list[Mark]]
@classmethod @classmethod
def from_item(cls, item: Item) -> MarkMatcher: def from_item(cls, item: Item) -> MarkMatcher:
mark_names = {mark.name for mark in item.iter_markers()} mark_name_mapping = collections.defaultdict(list)
return cls(mark_names) for mark in item.iter_markers():
mark_name_mapping[mark.name].append(mark)
return cls(mark_name_mapping)
def __call__(self, name: str) -> bool: def __call__(self, name: str, /, **kwargs: object) -> bool:
return name in self.own_mark_names if not (matches := self.own_mark_name_mapping.get(name, [])):
return False
if not kwargs:
return True
for mark in matches:
if all(
mark.kwargs.get(k, NOT_NONE_SENTINEL) == v for k, v in kwargs.items()
):
return True
return False
def deselect_by_mark(items: list[Item], config: Config) -> None: def deselect_by_mark(items: list[Item], config: Config) -> None:

View File

@ -5,7 +5,8 @@ The grammar is:
expression: expr? EOF expression: expr? EOF
expr: and_expr ('or' and_expr)* expr: and_expr ('or' and_expr)*
and_expr: not_expr ('and' not_expr)* and_expr: not_expr ('and' not_expr)*
not_expr: 'not' not_expr | '(' expr ')' | ident not_expr: 'not' not_expr | '(' expr ')' | ident ( '(' name '=' value ( ', ' name '=' value )* ')')*
ident: (\w|:|\+|-|\.|\[|\]|\\|/)+ ident: (\w|:|\+|-|\.|\[|\]|\\|/)+
The semantics are: The semantics are:
@ -20,12 +21,13 @@ from __future__ import annotations
import ast import ast
import dataclasses import dataclasses
import enum import enum
import keyword
import re import re
import types import types
from typing import Callable
from typing import Iterator from typing import Iterator
from typing import Mapping from typing import Mapping
from typing import NoReturn from typing import NoReturn
from typing import Protocol
from typing import Sequence from typing import Sequence
@ -43,6 +45,9 @@ class TokenType(enum.Enum):
NOT = "not" NOT = "not"
IDENT = "identifier" IDENT = "identifier"
EOF = "end of input" EOF = "end of input"
EQUAL = "="
STRING = "str"
COMMA = ","
@dataclasses.dataclass(frozen=True) @dataclasses.dataclass(frozen=True)
@ -86,6 +91,27 @@ class Scanner:
elif input[pos] == ")": elif input[pos] == ")":
yield Token(TokenType.RPAREN, ")", pos) yield Token(TokenType.RPAREN, ")", pos)
pos += 1 pos += 1
elif input[pos] == "=":
yield Token(TokenType.EQUAL, "=", pos)
pos += 1
elif input[pos] == ",":
yield Token(TokenType.COMMA, ",", pos)
pos += 1
elif (quote_char := input[pos]) == "'" or input[pos] == '"':
quote_position = input[pos + 1 :].find(quote_char)
if quote_position == -1:
raise ParseError(
pos + 1,
f'closing quote "{quote_char}" is missing',
)
value = input[pos : pos + 2 + quote_position]
if "\\" in value:
raise ParseError(
pos + 1,
"escaping not supported in marker expression",
)
yield Token(TokenType.STRING, value, pos)
pos += len(value)
else: else:
match = re.match(r"(:?\w|:|\+|-|\.|\[|\]|\\|/)+", input[pos:]) match = re.match(r"(:?\w|:|\+|-|\.|\[|\]|\\|/)+", input[pos:])
if match: if match:
@ -166,18 +192,84 @@ def not_expr(s: Scanner) -> ast.expr:
return ret return ret
ident = s.accept(TokenType.IDENT) ident = s.accept(TokenType.IDENT)
if ident: if ident:
return ast.Name(IDENT_PREFIX + ident.value, ast.Load()) name = ast.Name(IDENT_PREFIX + ident.value, ast.Load())
if s.accept(TokenType.LPAREN):
ret = ast.Call(func=name, args=[], keywords=all_kwargs(s))
s.accept(TokenType.RPAREN, reject=True)
else:
ret = name
return ret
s.reject((TokenType.NOT, TokenType.LPAREN, TokenType.IDENT)) s.reject((TokenType.NOT, TokenType.LPAREN, TokenType.IDENT))
class MatcherAdapter(Mapping[str, bool]): BUILTIN_MATCHERS = {"True": True, "False": False, "None": None}
def single_kwarg(s: Scanner) -> ast.keyword:
keyword_name = s.accept(TokenType.IDENT, reject=True)
assert keyword_name is not None # for mypy
if not keyword_name.value.isidentifier() or keyword.iskeyword(keyword_name.value):
raise ParseError(
keyword_name.pos + 1,
f'unexpected character/s "{keyword_name.value}"',
)
s.accept(TokenType.EQUAL, reject=True)
if value_token := s.accept(TokenType.STRING):
value: str | int | bool | None = value_token.value[1:-1] # strip quotes
else:
value_token = s.accept(TokenType.IDENT, reject=True)
assert value_token is not None # for mypy
if (
(number := value_token.value).isdigit()
or number.startswith("-")
and number[1:].isdigit()
):
value = int(number)
elif value_token.value in BUILTIN_MATCHERS:
value = BUILTIN_MATCHERS[value_token.value]
else:
raise ParseError(
value_token.pos + 1,
f'unexpected character/s "{value_token.value}"',
)
ret = ast.keyword(keyword_name.value, ast.Constant(value))
return ret
def all_kwargs(s: Scanner) -> list[ast.keyword]:
ret = [single_kwarg(s)]
while s.accept(TokenType.COMMA):
ret.append(single_kwarg(s))
return ret
class MatcherCall(Protocol):
def __call__(self, name: str, /, **kwargs: object) -> bool: ...
@dataclasses.dataclass
class MatcherNameAdapter:
matcher: MatcherCall
name: str
def __bool__(self) -> bool:
return self.matcher(self.name)
def __call__(self, **kwargs: object) -> bool:
return self.matcher(self.name, **kwargs)
class MatcherAdapter(Mapping[str, MatcherNameAdapter]):
"""Adapts a matcher function to a locals mapping as required by eval().""" """Adapts a matcher function to a locals mapping as required by eval()."""
def __init__(self, matcher: Callable[[str], bool]) -> None: def __init__(self, matcher: MatcherCall) -> None:
self.matcher = matcher self.matcher = matcher
def __getitem__(self, key: str) -> bool: def __getitem__(self, key: str) -> MatcherNameAdapter:
return self.matcher(key[len(IDENT_PREFIX) :]) return MatcherNameAdapter(matcher=self.matcher, name=key[len(IDENT_PREFIX) :])
def __iter__(self) -> Iterator[str]: def __iter__(self) -> Iterator[str]:
raise NotImplementedError() raise NotImplementedError()
@ -211,7 +303,7 @@ class Expression:
) )
return Expression(code) return Expression(code)
def evaluate(self, matcher: Callable[[str], bool]) -> bool: def evaluate(self, matcher: MatcherCall) -> bool:
"""Evaluate the match expression. """Evaluate the match expression.
:param matcher: :param matcher:
@ -220,5 +312,5 @@ class Expression:
:returns: Whether the expression matches or not. :returns: Whether the expression matches or not.
""" """
ret: bool = eval(self.code, {"__builtins__": {}}, MatcherAdapter(matcher)) ret: bool = bool(eval(self.code, {"__builtins__": {}}, MatcherAdapter(matcher)))
return ret return ret

View File

@ -233,6 +233,54 @@ def test_mark_option(
assert passed_str == expected_passed assert passed_str == expected_passed
@pytest.mark.parametrize(
("expr", "expected_passed"),
[ # TODO: improve/sort out
("car(color='red')", ["test_one"]),
("car(color='red') or car(color='blue')", ["test_one", "test_two"]),
("car and not car(temp=5)", ["test_one", "test_three"]),
("car(temp=4)", ["test_one"]),
("car(temp=4) or car(temp=5)", ["test_one", "test_two"]),
("car(temp=4) and car(temp=5)", []),
("car(temp=-5)", ["test_three"]),
("car(ac=True)", ["test_one"]),
("car(ac=False)", ["test_two"]),
("car(ac=None)", ["test_three"]), # test NOT_NONE_SENTINEL
],
ids=str,
)
def test_mark_option_with_kwargs(
expr: str, expected_passed: list[str | None], pytester: Pytester
) -> None:
pytester.makepyfile(
"""
import pytest
@pytest.mark.car
@pytest.mark.car(ac=True)
@pytest.mark.car(temp=4)
@pytest.mark.car(color="red")
def test_one():
pass
@pytest.mark.car
@pytest.mark.car(ac=False)
@pytest.mark.car(temp=5)
@pytest.mark.car(color="blue")
def test_two():
pass
@pytest.mark.car
@pytest.mark.car(ac=None)
@pytest.mark.car(temp=-5)
def test_three():
pass
"""
)
rec = pytester.inline_run("-m", expr)
passed, skipped, fail = rec.listoutcomes()
passed_str = [x.nodeid.split("::")[-1] for x in passed]
assert passed_str == expected_passed
@pytest.mark.parametrize( @pytest.mark.parametrize(
("expr", "expected_passed"), ("expr", "expected_passed"),
[("interface", ["test_interface"]), ("not interface", ["test_nointer"])], [("interface", ["test_interface"]), ("not interface", ["test_nointer"])],

View File

@ -1,14 +1,19 @@
from __future__ import annotations from __future__ import annotations
import collections
from typing import Callable from typing import Callable
from typing import cast
from _pytest.mark import MarkMatcher
from _pytest.mark import structures
from _pytest.mark.expression import Expression from _pytest.mark.expression import Expression
from _pytest.mark.expression import MatcherCall
from _pytest.mark.expression import ParseError from _pytest.mark.expression import ParseError
import pytest import pytest
def evaluate(input: str, matcher: Callable[[str], bool]) -> bool: def evaluate(input: str, matcher: Callable[[str], bool]) -> bool:
return Expression.compile(input).evaluate(matcher) return Expression.compile(input).evaluate(cast(MatcherCall, matcher))
def test_empty_is_false() -> None: def test_empty_is_false() -> None:
@ -153,6 +158,8 @@ def test_syntax_errors(expr: str, column: int, message: str) -> None:
"1234", "1234",
"1234abcd", "1234abcd",
"1234and", "1234and",
"1234or",
"1234not",
"notandor", "notandor",
"not_and_or", "not_and_or",
"not[and]or", "not[and]or",
@ -195,3 +202,123 @@ def test_valid_idents(ident: str) -> None:
def test_invalid_idents(ident: str) -> None: def test_invalid_idents(ident: str) -> None:
with pytest.raises(ParseError): with pytest.raises(ParseError):
evaluate(ident, lambda ident: True) evaluate(ident, lambda ident: True)
@pytest.mark.parametrize(
"expr, expected_error_msg",
(
("mark(1=2)", 'unexpected character/s "1"'),
("mark(/=2)", 'unexpected character/s "/"'),
("mark(True=False)", 'unexpected character/s "True"'),
("mark(def=False)", 'unexpected character/s "def"'),
("mark(class=False)", 'unexpected character/s "class"'),
("mark(if=False)", 'unexpected character/s "if"'),
("mark(else=False)", 'unexpected character/s "else"'),
("mark(1)", 'unexpected character/s "1"'),
("mark(var:=False", 'unexpected character/s "var:"'),
("mark(valid=False, def=1)", 'unexpected character/s "def"'),
("mark(var==", "expected identifier; got ="),
("mark(var=none)", 'unexpected character/s "none"'),
("mark(var=1.1)", 'unexpected character/s "1.1"'),
("mark(var)", "expected =; got right parenthesis"),
("mark(var=')", """closing quote "'" is missing"""),
('mark(var=")', 'closing quote """ is missing'),
("""mark(var="')""", 'closing quote """ is missing'),
("""mark(var='")""", """closing quote "'" is missing"""),
(r"mark(var='\hugo')", "escaping not supported in marker expression"),
),
)
def test_invalid_kwarg_name_or_value( # TODO: move to `test_syntax_errors` ?
expr: str, expected_error_msg: str, mark_matcher: MarkMatcher
) -> None:
with pytest.raises(ParseError, match=expected_error_msg):
assert evaluate(expr, mark_matcher)
@pytest.fixture(scope="session")
def mark_matcher() -> MarkMatcher:
markers = []
mark_name_mapping = collections.defaultdict(list)
def create_marker(name: str, kwargs: dict[str, object]) -> structures.Mark:
return structures.Mark(name=name, args=tuple(), kwargs=kwargs, _ispytest=True)
markers.append(create_marker("number_mark", {"a": 1, "b": 2, "c": 3, "d": 999_999}))
markers.append(
create_marker("builtin_matchers_mark", {"x": True, "y": False, "z": None})
)
markers.append(
create_marker(
"str_mark",
{"m": "M", "space": "with space", "aaאבגדcc": "aaאבגדcc", "אבגד": "אבגד"},
)
)
for marker in markers:
mark_name_mapping[marker.name].append(marker)
return MarkMatcher(mark_name_mapping)
@pytest.mark.parametrize(
"expr, expected",
(
# happy cases
("number_mark(a=1)", True),
("number_mark(b=2)", True),
("number_mark(a=1,b=2)", True),
("number_mark(a=1, b=2)", True),
("number_mark(d=999999)", True),
("number_mark(a = 1,b= 2, c = 3)", True),
# sad cases
("number_mark(a=6)", False),
("number_mark(b=6)", False),
("number_mark(a=1,b=6)", False),
("number_mark(a=6,b=2)", False),
("number_mark(a = 1,b= 2, c = 6)", False),
("number_mark(a='1')", False),
),
)
def test_keyword_expressions_with_numbers(
expr: str, expected: bool, mark_matcher: MarkMatcher
) -> None:
assert evaluate(expr, mark_matcher) is expected
@pytest.mark.parametrize(
"expr, expected",
(
("builtin_matchers_mark(x=True)", True),
("builtin_matchers_mark(x=False)", False),
("builtin_matchers_mark(y=True)", False),
("builtin_matchers_mark(y=False)", True),
("builtin_matchers_mark(z=None)", True),
("builtin_matchers_mark(z=False)", False),
("builtin_matchers_mark(z=True)", False),
("builtin_matchers_mark(z=0)", False),
("builtin_matchers_mark(z=1)", False),
),
)
def test_builtin_matchers_keyword_expressions( # TODO: naming when decided
expr: str, expected: bool, mark_matcher: MarkMatcher
) -> None:
assert evaluate(expr, mark_matcher) is expected
@pytest.mark.parametrize(
"expr, expected",
(
("str_mark(m='M')", True),
('str_mark(m="M")', True),
("str_mark(aaאבגדcc='aaאבגדcc')", True),
("str_mark(אבגד='אבגד')", True),
("str_mark(space='with space')", True),
("str_mark(m='wrong')", False),
("str_mark(aaאבגדcc='wrong')", False),
("str_mark(אבגד='wrong')", False),
),
)
def test_str_keyword_expressions(
expr: str, expected: bool, mark_matcher: MarkMatcher
) -> None:
assert evaluate(expr, mark_matcher) is expected