diff --git a/src/_pytest/mark/__init__.py b/src/_pytest/mark/__init__.py index b8a309215..ae6940a62 100644 --- a/src/_pytest/mark/__init__.py +++ b/src/_pytest/mark/__init__.py @@ -2,6 +2,7 @@ from __future__ import annotations +import collections import dataclasses from typing import AbstractSet from typing import Collection @@ -181,7 +182,9 @@ class KeywordMatcher: return cls(mapped_names) - def __call__(self, subname: str) -> bool: + def __call__(self, subname: str, /, **kwargs: object) -> bool: + if kwargs: + raise UsageError("Keyword expressions do not support call parameters.") subname = subname.lower() names = (name.lower() for name in self._names) @@ -211,6 +214,9 @@ def deselect_by_keyword(items: list[Item], config: Config) -> None: items[:] = remaining +NOT_NONE_SENTINEL = object() + + @dataclasses.dataclass class MarkMatcher: """A matcher for markers which are present. @@ -218,17 +224,31 @@ class MarkMatcher: Tries to match on any marker names, attached to the given colitem. """ - __slots__ = ("own_mark_names",) + __slots__ = ("own_mark_name_mapping",) - own_mark_names: AbstractSet[str] + own_mark_name_mapping: dict[str, list[Mark]] @classmethod def from_item(cls, item: Item) -> MarkMatcher: - mark_names = {mark.name for mark in item.iter_markers()} - return cls(mark_names) + mark_name_mapping = collections.defaultdict(list) + for mark in item.iter_markers(): + mark_name_mapping[mark.name].append(mark) + return cls(mark_name_mapping) - def __call__(self, name: str) -> bool: - return name in self.own_mark_names + def __call__(self, name: str, /, **kwargs: object) -> bool: + if not (matches := self.own_mark_name_mapping.get(name, [])): + return False + + if not kwargs: + return True + + for mark in matches: + if all( + mark.kwargs.get(k, NOT_NONE_SENTINEL) == v for k, v in kwargs.items() + ): + return True + + return False def deselect_by_mark(items: list[Item], config: Config) -> None: diff --git a/src/_pytest/mark/expression.py b/src/_pytest/mark/expression.py index e65b02858..16883c6b7 100644 --- a/src/_pytest/mark/expression.py +++ b/src/_pytest/mark/expression.py @@ -5,7 +5,8 @@ The grammar is: expression: expr? EOF expr: and_expr ('or' and_expr)* and_expr: not_expr ('and' not_expr)* -not_expr: 'not' not_expr | '(' expr ')' | ident +not_expr: 'not' not_expr | '(' expr ')' | ident ( '(' name '=' value ( ', ' name '=' value )* ')')* + ident: (\w|:|\+|-|\.|\[|\]|\\|/)+ The semantics are: @@ -20,12 +21,13 @@ from __future__ import annotations import ast import dataclasses import enum +import keyword import re import types -from typing import Callable from typing import Iterator from typing import Mapping from typing import NoReturn +from typing import Protocol from typing import Sequence @@ -43,6 +45,9 @@ class TokenType(enum.Enum): NOT = "not" IDENT = "identifier" EOF = "end of input" + EQUAL = "=" + STRING = "str" + COMMA = "," @dataclasses.dataclass(frozen=True) @@ -86,6 +91,27 @@ class Scanner: elif input[pos] == ")": yield Token(TokenType.RPAREN, ")", pos) pos += 1 + elif input[pos] == "=": + yield Token(TokenType.EQUAL, "=", pos) + pos += 1 + elif input[pos] == ",": + yield Token(TokenType.COMMA, ",", pos) + pos += 1 + elif (quote_char := input[pos]) == "'" or input[pos] == '"': + quote_position = input[pos + 1 :].find(quote_char) + if quote_position == -1: + raise ParseError( + pos + 1, + f'closing quote "{quote_char}" is missing', + ) + value = input[pos : pos + 2 + quote_position] + if "\\" in value: + raise ParseError( + pos + 1, + "escaping not supported in marker expression", + ) + yield Token(TokenType.STRING, value, pos) + pos += len(value) else: match = re.match(r"(:?\w|:|\+|-|\.|\[|\]|\\|/)+", input[pos:]) if match: @@ -166,18 +192,84 @@ def not_expr(s: Scanner) -> ast.expr: return ret ident = s.accept(TokenType.IDENT) if ident: - return ast.Name(IDENT_PREFIX + ident.value, ast.Load()) + name = ast.Name(IDENT_PREFIX + ident.value, ast.Load()) + if s.accept(TokenType.LPAREN): + ret = ast.Call(func=name, args=[], keywords=all_kwargs(s)) + s.accept(TokenType.RPAREN, reject=True) + else: + ret = name + return ret + s.reject((TokenType.NOT, TokenType.LPAREN, TokenType.IDENT)) -class MatcherAdapter(Mapping[str, bool]): +BUILTIN_MATCHERS = {"True": True, "False": False, "None": None} + + +def single_kwarg(s: Scanner) -> ast.keyword: + keyword_name = s.accept(TokenType.IDENT, reject=True) + assert keyword_name is not None # for mypy + if not keyword_name.value.isidentifier() or keyword.iskeyword(keyword_name.value): + raise ParseError( + keyword_name.pos + 1, + f'unexpected character/s "{keyword_name.value}"', + ) + s.accept(TokenType.EQUAL, reject=True) + + if value_token := s.accept(TokenType.STRING): + value: str | int | bool | None = value_token.value[1:-1] # strip quotes + else: + value_token = s.accept(TokenType.IDENT, reject=True) + assert value_token is not None # for mypy + if ( + (number := value_token.value).isdigit() + or number.startswith("-") + and number[1:].isdigit() + ): + value = int(number) + elif value_token.value in BUILTIN_MATCHERS: + value = BUILTIN_MATCHERS[value_token.value] + else: + raise ParseError( + value_token.pos + 1, + f'unexpected character/s "{value_token.value}"', + ) + + ret = ast.keyword(keyword_name.value, ast.Constant(value)) + return ret + + +def all_kwargs(s: Scanner) -> list[ast.keyword]: + ret = [single_kwarg(s)] + while s.accept(TokenType.COMMA): + ret.append(single_kwarg(s)) + return ret + + +class MatcherCall(Protocol): + def __call__(self, name: str, /, **kwargs: object) -> bool: ... + + +@dataclasses.dataclass +class MatcherNameAdapter: + matcher: MatcherCall + name: str + + def __bool__(self) -> bool: + return self.matcher(self.name) + + def __call__(self, **kwargs: object) -> bool: + return self.matcher(self.name, **kwargs) + + +class MatcherAdapter(Mapping[str, MatcherNameAdapter]): """Adapts a matcher function to a locals mapping as required by eval().""" - def __init__(self, matcher: Callable[[str], bool]) -> None: + def __init__(self, matcher: MatcherCall) -> None: self.matcher = matcher - def __getitem__(self, key: str) -> bool: - return self.matcher(key[len(IDENT_PREFIX) :]) + def __getitem__(self, key: str) -> MatcherNameAdapter: + return MatcherNameAdapter(matcher=self.matcher, name=key[len(IDENT_PREFIX) :]) def __iter__(self) -> Iterator[str]: raise NotImplementedError() @@ -211,7 +303,7 @@ class Expression: ) return Expression(code) - def evaluate(self, matcher: Callable[[str], bool]) -> bool: + def evaluate(self, matcher: MatcherCall) -> bool: """Evaluate the match expression. :param matcher: @@ -220,5 +312,5 @@ class Expression: :returns: Whether the expression matches or not. """ - ret: bool = eval(self.code, {"__builtins__": {}}, MatcherAdapter(matcher)) + ret: bool = bool(eval(self.code, {"__builtins__": {}}, MatcherAdapter(matcher))) return ret diff --git a/testing/test_mark.py b/testing/test_mark.py index 090e10ee9..721bb71d3 100644 --- a/testing/test_mark.py +++ b/testing/test_mark.py @@ -233,6 +233,54 @@ def test_mark_option( assert passed_str == expected_passed +@pytest.mark.parametrize( + ("expr", "expected_passed"), + [ # TODO: improve/sort out + ("car(color='red')", ["test_one"]), + ("car(color='red') or car(color='blue')", ["test_one", "test_two"]), + ("car and not car(temp=5)", ["test_one", "test_three"]), + ("car(temp=4)", ["test_one"]), + ("car(temp=4) or car(temp=5)", ["test_one", "test_two"]), + ("car(temp=4) and car(temp=5)", []), + ("car(temp=-5)", ["test_three"]), + ("car(ac=True)", ["test_one"]), + ("car(ac=False)", ["test_two"]), + ("car(ac=None)", ["test_three"]), # test NOT_NONE_SENTINEL + ], + ids=str, +) +def test_mark_option_with_kwargs( + expr: str, expected_passed: list[str | None], pytester: Pytester +) -> None: + pytester.makepyfile( + """ + import pytest + @pytest.mark.car + @pytest.mark.car(ac=True) + @pytest.mark.car(temp=4) + @pytest.mark.car(color="red") + def test_one(): + pass + @pytest.mark.car + @pytest.mark.car(ac=False) + @pytest.mark.car(temp=5) + @pytest.mark.car(color="blue") + def test_two(): + pass + @pytest.mark.car + @pytest.mark.car(ac=None) + @pytest.mark.car(temp=-5) + def test_three(): + pass + + """ + ) + rec = pytester.inline_run("-m", expr) + passed, skipped, fail = rec.listoutcomes() + passed_str = [x.nodeid.split("::")[-1] for x in passed] + assert passed_str == expected_passed + + @pytest.mark.parametrize( ("expr", "expected_passed"), [("interface", ["test_interface"]), ("not interface", ["test_nointer"])], diff --git a/testing/test_mark_expression.py b/testing/test_mark_expression.py index 5bce004cb..0c1e73809 100644 --- a/testing/test_mark_expression.py +++ b/testing/test_mark_expression.py @@ -1,14 +1,19 @@ from __future__ import annotations +import collections from typing import Callable +from typing import cast +from _pytest.mark import MarkMatcher +from _pytest.mark import structures from _pytest.mark.expression import Expression +from _pytest.mark.expression import MatcherCall from _pytest.mark.expression import ParseError import pytest def evaluate(input: str, matcher: Callable[[str], bool]) -> bool: - return Expression.compile(input).evaluate(matcher) + return Expression.compile(input).evaluate(cast(MatcherCall, matcher)) def test_empty_is_false() -> None: @@ -153,6 +158,8 @@ def test_syntax_errors(expr: str, column: int, message: str) -> None: "1234", "1234abcd", "1234and", + "1234or", + "1234not", "notandor", "not_and_or", "not[and]or", @@ -195,3 +202,123 @@ def test_valid_idents(ident: str) -> None: def test_invalid_idents(ident: str) -> None: with pytest.raises(ParseError): evaluate(ident, lambda ident: True) + + +@pytest.mark.parametrize( + "expr, expected_error_msg", + ( + ("mark(1=2)", 'unexpected character/s "1"'), + ("mark(/=2)", 'unexpected character/s "/"'), + ("mark(True=False)", 'unexpected character/s "True"'), + ("mark(def=False)", 'unexpected character/s "def"'), + ("mark(class=False)", 'unexpected character/s "class"'), + ("mark(if=False)", 'unexpected character/s "if"'), + ("mark(else=False)", 'unexpected character/s "else"'), + ("mark(1)", 'unexpected character/s "1"'), + ("mark(var:=False", 'unexpected character/s "var:"'), + ("mark(valid=False, def=1)", 'unexpected character/s "def"'), + ("mark(var==", "expected identifier; got ="), + ("mark(var=none)", 'unexpected character/s "none"'), + ("mark(var=1.1)", 'unexpected character/s "1.1"'), + ("mark(var)", "expected =; got right parenthesis"), + ("mark(var=')", """closing quote "'" is missing"""), + ('mark(var=")', 'closing quote """ is missing'), + ("""mark(var="')""", 'closing quote """ is missing'), + ("""mark(var='")""", """closing quote "'" is missing"""), + (r"mark(var='\hugo')", "escaping not supported in marker expression"), + ), +) +def test_invalid_kwarg_name_or_value( # TODO: move to `test_syntax_errors` ? + expr: str, expected_error_msg: str, mark_matcher: MarkMatcher +) -> None: + with pytest.raises(ParseError, match=expected_error_msg): + assert evaluate(expr, mark_matcher) + + +@pytest.fixture(scope="session") +def mark_matcher() -> MarkMatcher: + markers = [] + mark_name_mapping = collections.defaultdict(list) + + def create_marker(name: str, kwargs: dict[str, object]) -> structures.Mark: + return structures.Mark(name=name, args=tuple(), kwargs=kwargs, _ispytest=True) + + markers.append(create_marker("number_mark", {"a": 1, "b": 2, "c": 3, "d": 999_999})) + markers.append( + create_marker("builtin_matchers_mark", {"x": True, "y": False, "z": None}) + ) + markers.append( + create_marker( + "str_mark", + {"m": "M", "space": "with space", "aaאבגדcc": "aaאבגדcc", "אבגד": "אבגד"}, + ) + ) + + for marker in markers: + mark_name_mapping[marker.name].append(marker) + + return MarkMatcher(mark_name_mapping) + + +@pytest.mark.parametrize( + "expr, expected", + ( + # happy cases + ("number_mark(a=1)", True), + ("number_mark(b=2)", True), + ("number_mark(a=1,b=2)", True), + ("number_mark(a=1, b=2)", True), + ("number_mark(d=999999)", True), + ("number_mark(a = 1,b= 2, c = 3)", True), + # sad cases + ("number_mark(a=6)", False), + ("number_mark(b=6)", False), + ("number_mark(a=1,b=6)", False), + ("number_mark(a=6,b=2)", False), + ("number_mark(a = 1,b= 2, c = 6)", False), + ("number_mark(a='1')", False), + ), +) +def test_keyword_expressions_with_numbers( + expr: str, expected: bool, mark_matcher: MarkMatcher +) -> None: + assert evaluate(expr, mark_matcher) is expected + + +@pytest.mark.parametrize( + "expr, expected", + ( + ("builtin_matchers_mark(x=True)", True), + ("builtin_matchers_mark(x=False)", False), + ("builtin_matchers_mark(y=True)", False), + ("builtin_matchers_mark(y=False)", True), + ("builtin_matchers_mark(z=None)", True), + ("builtin_matchers_mark(z=False)", False), + ("builtin_matchers_mark(z=True)", False), + ("builtin_matchers_mark(z=0)", False), + ("builtin_matchers_mark(z=1)", False), + ), +) +def test_builtin_matchers_keyword_expressions( # TODO: naming when decided + expr: str, expected: bool, mark_matcher: MarkMatcher +) -> None: + assert evaluate(expr, mark_matcher) is expected + + +@pytest.mark.parametrize( + "expr, expected", + ( + ("str_mark(m='M')", True), + ('str_mark(m="M")', True), + ("str_mark(aaאבגדcc='aaאבגדcc')", True), + ("str_mark(אבגד='אבגד')", True), + ("str_mark(space='with space')", True), + ("str_mark(m='wrong')", False), + ("str_mark(aaאבגדcc='wrong')", False), + ("str_mark(אבגד='wrong')", False), + ), +) +def test_str_keyword_expressions( + expr: str, expected: bool, mark_matcher: MarkMatcher +) -> None: + assert evaluate(expr, mark_matcher) is expected