feat: support keyword arguments in marker expressions

Fixes #12281
This commit is contained in:
lovetheguitar
2024-06-20 23:09:14 +02:00
parent e8fa8dd31c
commit 15c33fbaa3
4 changed files with 304 additions and 17 deletions

View File

@@ -2,6 +2,7 @@
from __future__ import annotations
import collections
import dataclasses
from typing import AbstractSet
from typing import Collection
@@ -181,7 +182,9 @@ class KeywordMatcher:
return cls(mapped_names)
def __call__(self, subname: str) -> bool:
def __call__(self, subname: str, /, **kwargs: object) -> bool:
if kwargs:
raise UsageError("Keyword expressions do not support call parameters.")
subname = subname.lower()
names = (name.lower() for name in self._names)
@@ -211,6 +214,9 @@ def deselect_by_keyword(items: list[Item], config: Config) -> None:
items[:] = remaining
NOT_NONE_SENTINEL = object()
@dataclasses.dataclass
class MarkMatcher:
"""A matcher for markers which are present.
@@ -218,17 +224,31 @@ class MarkMatcher:
Tries to match on any marker names, attached to the given colitem.
"""
__slots__ = ("own_mark_names",)
__slots__ = ("own_mark_name_mapping",)
own_mark_names: AbstractSet[str]
own_mark_name_mapping: dict[str, list[Mark]]
@classmethod
def from_item(cls, item: Item) -> MarkMatcher:
mark_names = {mark.name for mark in item.iter_markers()}
return cls(mark_names)
mark_name_mapping = collections.defaultdict(list)
for mark in item.iter_markers():
mark_name_mapping[mark.name].append(mark)
return cls(mark_name_mapping)
def __call__(self, name: str) -> bool:
return name in self.own_mark_names
def __call__(self, name: str, /, **kwargs: object) -> bool:
if not (matches := self.own_mark_name_mapping.get(name, [])):
return False
if not kwargs:
return True
for mark in matches:
if all(
mark.kwargs.get(k, NOT_NONE_SENTINEL) == v for k, v in kwargs.items()
):
return True
return False
def deselect_by_mark(items: list[Item], config: Config) -> None:

View File

@@ -5,7 +5,8 @@ The grammar is:
expression: expr? EOF
expr: and_expr ('or' and_expr)*
and_expr: not_expr ('and' not_expr)*
not_expr: 'not' not_expr | '(' expr ')' | ident
not_expr: 'not' not_expr | '(' expr ')' | ident ( '(' name '=' value ( ', ' name '=' value )* ')')*
ident: (\w|:|\+|-|\.|\[|\]|\\|/)+
The semantics are:
@@ -20,12 +21,13 @@ from __future__ import annotations
import ast
import dataclasses
import enum
import keyword
import re
import types
from typing import Callable
from typing import Iterator
from typing import Mapping
from typing import NoReturn
from typing import Protocol
from typing import Sequence
@@ -43,6 +45,9 @@ class TokenType(enum.Enum):
NOT = "not"
IDENT = "identifier"
EOF = "end of input"
EQUAL = "="
STRING = "str"
COMMA = ","
@dataclasses.dataclass(frozen=True)
@@ -86,6 +91,27 @@ class Scanner:
elif input[pos] == ")":
yield Token(TokenType.RPAREN, ")", pos)
pos += 1
elif input[pos] == "=":
yield Token(TokenType.EQUAL, "=", pos)
pos += 1
elif input[pos] == ",":
yield Token(TokenType.COMMA, ",", pos)
pos += 1
elif (quote_char := input[pos]) == "'" or input[pos] == '"':
quote_position = input[pos + 1 :].find(quote_char)
if quote_position == -1:
raise ParseError(
pos + 1,
f'closing quote "{quote_char}" is missing',
)
value = input[pos : pos + 2 + quote_position]
if "\\" in value:
raise ParseError(
pos + 1,
"escaping not supported in marker expression",
)
yield Token(TokenType.STRING, value, pos)
pos += len(value)
else:
match = re.match(r"(:?\w|:|\+|-|\.|\[|\]|\\|/)+", input[pos:])
if match:
@@ -166,18 +192,84 @@ def not_expr(s: Scanner) -> ast.expr:
return ret
ident = s.accept(TokenType.IDENT)
if ident:
return ast.Name(IDENT_PREFIX + ident.value, ast.Load())
name = ast.Name(IDENT_PREFIX + ident.value, ast.Load())
if s.accept(TokenType.LPAREN):
ret = ast.Call(func=name, args=[], keywords=all_kwargs(s))
s.accept(TokenType.RPAREN, reject=True)
else:
ret = name
return ret
s.reject((TokenType.NOT, TokenType.LPAREN, TokenType.IDENT))
class MatcherAdapter(Mapping[str, bool]):
BUILTIN_MATCHERS = {"True": True, "False": False, "None": None}
def single_kwarg(s: Scanner) -> ast.keyword:
keyword_name = s.accept(TokenType.IDENT, reject=True)
assert keyword_name is not None # for mypy
if not keyword_name.value.isidentifier() or keyword.iskeyword(keyword_name.value):
raise ParseError(
keyword_name.pos + 1,
f'unexpected character/s "{keyword_name.value}"',
)
s.accept(TokenType.EQUAL, reject=True)
if value_token := s.accept(TokenType.STRING):
value: str | int | bool | None = value_token.value[1:-1] # strip quotes
else:
value_token = s.accept(TokenType.IDENT, reject=True)
assert value_token is not None # for mypy
if (
(number := value_token.value).isdigit()
or number.startswith("-")
and number[1:].isdigit()
):
value = int(number)
elif value_token.value in BUILTIN_MATCHERS:
value = BUILTIN_MATCHERS[value_token.value]
else:
raise ParseError(
value_token.pos + 1,
f'unexpected character/s "{value_token.value}"',
)
ret = ast.keyword(keyword_name.value, ast.Constant(value))
return ret
def all_kwargs(s: Scanner) -> list[ast.keyword]:
ret = [single_kwarg(s)]
while s.accept(TokenType.COMMA):
ret.append(single_kwarg(s))
return ret
class MatcherCall(Protocol):
def __call__(self, name: str, /, **kwargs: object) -> bool: ...
@dataclasses.dataclass
class MatcherNameAdapter:
matcher: MatcherCall
name: str
def __bool__(self) -> bool:
return self.matcher(self.name)
def __call__(self, **kwargs: object) -> bool:
return self.matcher(self.name, **kwargs)
class MatcherAdapter(Mapping[str, MatcherNameAdapter]):
"""Adapts a matcher function to a locals mapping as required by eval()."""
def __init__(self, matcher: Callable[[str], bool]) -> None:
def __init__(self, matcher: MatcherCall) -> None:
self.matcher = matcher
def __getitem__(self, key: str) -> bool:
return self.matcher(key[len(IDENT_PREFIX) :])
def __getitem__(self, key: str) -> MatcherNameAdapter:
return MatcherNameAdapter(matcher=self.matcher, name=key[len(IDENT_PREFIX) :])
def __iter__(self) -> Iterator[str]:
raise NotImplementedError()
@@ -211,7 +303,7 @@ class Expression:
)
return Expression(code)
def evaluate(self, matcher: Callable[[str], bool]) -> bool:
def evaluate(self, matcher: MatcherCall) -> bool:
"""Evaluate the match expression.
:param matcher:
@@ -220,5 +312,5 @@ class Expression:
:returns: Whether the expression matches or not.
"""
ret: bool = eval(self.code, {"__builtins__": {}}, MatcherAdapter(matcher))
ret: bool = bool(eval(self.code, {"__builtins__": {}}, MatcherAdapter(matcher)))
return ret