Source code for macro_polo.match

"""Macro input pattern matching utilities."""

from collections.abc import Iterator, Mapping, Sequence
from dataclasses import dataclass
from enum import Enum
from functools import cache
import tokenize
from typing import Literal, TypeAlias, Union, assert_never

from . import MacroError
from ._utils import SliceView, TupleNewType
from .tokens import Delimiter, Token, TokenTree


class MacroMatchError(MacroError):
    """Error type for problems encountered while matching macro input."""
MacroMatcherItem: TypeAlias = Union[
    Token,
    'DelimitedMacroMatcher',
    'MacroMatcherVar',
    'MacroMatcherRepeater',
    'MacroMatcherUnion',
    'MacroMatcherNegativeLookahead',
]
"""Any element that may appear inside a :class:`MacroMatcher`."""

MacroMatcherCapture: TypeAlias = Union[
    Token,
    TokenTree,
    list['MacroMatcherCapture'],
    'MacroMatcherEmptyCapture',
]
"""Token(s) bound to a capture variable; lists arise from repeaters."""

MacroMatchCaptures: TypeAlias = Mapping[str, MacroMatcherCapture]
@dataclass(frozen=True, slots=True)
class MacroMatch:
    """Outcome of a successful match attempt."""

    size: int
    """How many input tokens the match consumed."""

    captures: MacroMatchCaptures
    """Tokens bound to each capture variable."""
@dataclass(frozen=True, slots=True)
class MacroMatcherEmptyCapture:
    """Placeholder for a capture that matched zero times.

    It is falsy and iterates as empty, but remembers how deeply it is nested
    inside repeaters, primarily so transcription errors can be reported more
    precisely.
    """

    depth: int
    """Repeater nesting depth."""

    def __bool__(self) -> Literal[False]:
        # Always falsy, like an empty list.
        return False

    def __iter__(self) -> Iterator[Token]:
        # Yields no tokens at all.
        return iter(())
class MacroMatcher(TupleNewType[MacroMatcherItem]):
    """A sequence of matcher items that must match one after another.

    :type args: :class:`MacroMatcherItem`
    """

    def match(self, tokens: Sequence[Token]) -> MacroMatch | None:
        """Try to match this pattern against a prefix of ``tokens``.

        Items are matched in order, each against the input left over by the
        previous one; their captures are merged into a single mapping.

        :returns: a :class:`MacroMatch` on success, otherwise ``None``.
        :raises ValueError: if matching reaches a bare delimiter token in the
            pattern (delimiters must be wrapped in a DelimitedMacroMatcher).
        """
        total = len(tokens)
        remaining = SliceView(tokens)
        bound: dict[str, MacroMatcherCapture] = {}

        for item in self:
            if not isinstance(item, Token):
                sub = item.match(remaining)
                if sub is None:
                    return None
                remaining = remaining[sub.size :]
                bound |= sub.captures
                continue

            # Literal token: must be a non-delimiter and equal the next input.
            if Delimiter.from_token(item):
                raise ValueError(
                    'delimiter tokens cannot be matched directly, '
                    'use DelimitedMacroMatcher instead'
                )
            if len(remaining) == 0 or remaining.popleft() != item:
                return None

        return MacroMatch(size=total - len(remaining), captures=bound)

    def full_match(self, tokens: Sequence[Token]) -> MacroMatch | None:
        """Match only if the pattern consumes every token in ``tokens``."""
        result = self.match(tokens)
        if result is None or result.size != len(tokens):
            return None
        return result
@dataclass(frozen=True, slots=True)
class DelimitedMacroMatcher:
    """Matches an opening delimiter, an inner pattern, then the balancing closer."""

    delimiter: Delimiter
    """The delimiter to match."""

    matcher: MacroMatcher
    """Pattern that must consume everything between the delimiters."""

    def match(self, tokens: Sequence[Token]) -> MacroMatch | None:
        """Try to match a delimited group at the start of ``tokens``.

        :returns: a :class:`MacroMatch` whose size includes both delimiter
            tokens, or ``None`` if the input does not open with the delimiter,
            the delimiter is never closed, or the inner pattern does not
            consume exactly the enclosed tokens.
        """
        if len(tokens) < 2 or not self.delimiter.matches_open(tokens[0]):
            return None

        body = SliceView(tokens)
        body.popleft()  # drop the opening delimiter

        # Find the closer that balances the opener, skipping nested pairs of
        # the same delimiter kind.
        nesting = 0
        close_index = None
        for index, token in enumerate(body):
            if self.delimiter.matches_open(token):
                nesting += 1
            elif self.delimiter.matches_close(token):
                if nesting == 0:
                    close_index = index
                    break
                nesting -= 1
        if close_index is None:
            return None

        inner = self.matcher.full_match(body[:close_index])
        if inner is None:
            return None
        # +2 accounts for the opening and closing delimiter tokens.
        return MacroMatch(size=inner.size + 2, captures=inner.captures)
class MacroMatcherVarType(Enum):
    """Kind of input a capture variable accepts."""

    TOKEN = 'token'
    """Any single token that is not a delimiter."""

    NAME = 'name'
    """A single :data:`token.NAME` token."""

    OP = 'op'
    """A single non-delimiter :data:`token.OP` token."""

    NUMBER = 'number'
    """A single :data:`token.NUMBER` token."""

    STRING = 'string'
    """A single :data:`token.STRING` token."""

    TOKEN_TREE = 'tt'
    """A non-delimiter token, or a whole delimited sequence of tokens."""

    NULL = 'null'
    """Matches without consuming input, capturing an empty :class:`~macro_polo.tokens.TokenTree`."""
[docs] @dataclass(frozen=True, slots=True) class MacroMatcherVar: """A capture-variable in a macro matcher.""" name: str """The name to bind captured tokens to.""" type: MacroMatcherVarType """The type of token(s) to match.""" _token_types = { MacroMatcherVarType.NAME: tokenize.NAME, MacroMatcherVarType.OP: tokenize.OP, MacroMatcherVarType.NUMBER: tokenize.NUMBER, MacroMatcherVarType.STRING: tokenize.STRING, }
[docs] def match(self, tokens: Sequence[Token]) -> MacroMatch | None: """Attempt to match against a token sequence.""" match self.type: case MacroMatcherVarType.NULL: return MacroMatch( size=0, captures={self.name: TokenTree()} if self.name != '_' else {}, ) case _ if len(tokens) < 1: return None case MacroMatcherVarType.TOKEN_TREE: tokens = SliceView(tokens) first_token = tokens.popleft() matched_tokens: list[Token] = [first_token] if delimiter := Delimiter.from_token(first_token): depth = 0 # Match until end delimiter is found while token := tokens.popleft(): matched_tokens.append(token) if delimiter.matches_open(token): depth += 1 elif delimiter.matches_close(token): if depth == 0: break depth -= 1 return MacroMatch( size=len(matched_tokens), captures=( {self.name: TokenTree(*matched_tokens)} if self.name != '_' else {} ), ) case _ if Delimiter.from_token(tokens[0]): # Delimiters can only be matched by TOKEN_TREE return None case MacroMatcherVarType.TOKEN: return MacroMatch( size=1, captures={self.name: tokens[0]} if self.name != '_' else {}, ) case ( MacroMatcherVarType.NAME | MacroMatcherVarType.OP | MacroMatcherVarType.NUMBER | MacroMatcherVarType.STRING ): token = tokens[0] if token.type != MacroMatcherVar._token_types[self.type]: return None return MacroMatch( size=1, captures={self.name: token} if self.name != '_' else {} ) case _: assert_never(self.type)
class MacroMatcherRepeaterMode(Enum):
    """How many times a :class:`MacroMatcherRepeater` may match."""

    ZERO_OR_ONE = '?'
    """At most once (≤1)."""

    ZERO_OR_MORE = '*'
    """Any number of times, including zero (≥0)."""

    ONE_OR_MORE = '+'
    """At least once (≥1)."""
@dataclass(frozen=True, slots=True)
class MacroMatcherRepeater:
    """A repeated sub-matcher."""

    matcher: MacroMatcher
    """The matcher to repeat."""

    mode: MacroMatcherRepeaterMode
    """The repetition mode."""

    sep: Token | None = None
    """An optional separator token."""

    @property
    def base_captures(self) -> Mapping[str, MacroMatcherEmptyCapture]:
        """Get a set of empty captures for this matcher.

        This is used to provide empty capture lists for matchers that match
        zero times, allowing transcribers to handle empty captures properly.
        """
        return _base_captures_from_matcher(self.matcher)

    def match(self, tokens: Sequence[Token]) -> MacroMatch | None:
        """Attempt to match against a token sequence.

        The sub-matcher is applied repeatedly (at most once in ``?`` mode).
        If ``sep`` is set, it is required between consecutive repetitions, and
        a trailing separator is left unconsumed. Each capture variable is bound
        to a list with one entry per repetition. If nothing was captured, the
        empty captures from :attr:`base_captures` are returned instead.

        :returns: ``None`` only in ``+`` mode when the first repetition fails;
            otherwise a :class:`MacroMatch` (possibly of size 0).
        """
        start_size = len(tokens)
        tokens = SliceView(tokens)
        captures: dict[str, list[MacroMatcherCapture]] = {}
        first = True
        while True:
            match_sep = not first and self.sep
            if match_sep:
                if len(tokens) < 1 or tokens[0] != self.sep:
                    break

            match = self.matcher.match(tokens[1:] if match_sep else tokens)
            if match is None:
                if first and self.mode is MacroMatcherRepeaterMode.ONE_OR_MORE:
                    return None
                break

            # Only pop sep if self.matcher matches, to prevent consuming trailing
            # separators.
            if match_sep:
                tokens.popleft()
            tokens = tokens[match.size :]

            for name, capture in match.captures.items():
                captures.setdefault(name, []).append(capture)

            if self.mode is MacroMatcherRepeaterMode.ZERO_OR_ONE:
                break
            if self.sep is None and match.size == 0:
                # Without a separator a zero-width repetition consumes nothing,
                # so trying again would match identically forever.
                break
            first = False

        size = start_size - len(tokens)
        if not captures:
            # Copy: base_captures comes from a memoized helper, and handing out
            # the cached dict would let callers corrupt it.
            return MacroMatch(size=size, captures=dict(self.base_captures))
        return MacroMatch(size=size, captures=captures)
class MacroMatcherUnion(TupleNewType[MacroMatcher]):
    """A union of macro matchers.

    Variants are tried in order and the first sub-matcher to match is used.
    All variants must bind the same capture variables at the same repeater
    nesting depths, so the captures have a consistent shape whichever
    variant matched.
    """

    def __new__(cls, *args: MacroMatcher):
        """Create a new `MacroMatcherUnion`.

        :raises MacroError: if no variants are given, or if the variants do
            not have identical capture variables at equivalent nesting depths.
        """
        # Reject the degenerate case before constructing anything.
        if len(args) < 1:
            raise MacroError('Union must have at least one variant.')
        self = super().__new__(cls, *args)
        captures = _base_captures_from_matcher(self[0])
        for matcher in self:
            if _base_captures_from_matcher(matcher) != captures:
                raise MacroError(
                    'All union variants must have identical capture variables at '
                    'equivalent nesting depths.'
                )
        return self

    def match(self, tokens: Sequence[Token]) -> MacroMatch | None:
        """Return the match of the first variant that matches, else ``None``."""
        for matcher in self:
            if (match := matcher.match(tokens)) is not None:
                return match
        return None
class MacroMatcherNegativeLookahead(TupleNewType[MacroMatcherItem]):
    """A negative lookahead macro match.

    Treats its items as a :class:`MacroMatcher`. It succeeds, consuming and
    capturing nothing, exactly when that matcher fails at the current
    position, and fails otherwise.

    :type args: :class:`MacroMatcherItem`
    """

    @property
    @cache
    def _matcher(self) -> MacroMatcher:
        # Memoized via functools.cache, keyed on self.
        # NOTE(review): the cache holds a strong reference to every instance.
        return MacroMatcher(*self)

    def match(self, tokens: Sequence[Token]) -> MacroMatch | None:
        """Succeed with an empty zero-size match iff the lookahead fails."""
        if self._matcher.match(tokens) is not None:
            return None
        return MacroMatch(size=0, captures={})
@cache def _base_captures_from_matcher( matcher: MacroMatcher, ) -> dict[str, MacroMatcherEmptyCapture]: """Get a set of empty captures for the given pattern. The return value is the expected result of matching against this pattern, wrapped in a zero-or-one repeater, with zero matches. In other words, a dict containing an empty list for each capture variable, at the appropriate nesting level. """ captures: dict[str, MacroMatcherEmptyCapture] = {} for item in matcher: match item: case Token() | MacroMatcherNegativeLookahead(): pass case DelimitedMacroMatcher(): captures.update(_base_captures_from_matcher(item.matcher)) case MacroMatcherVar(): if item.name != '_': captures[item.name] = MacroMatcherEmptyCapture(0) case MacroMatcherRepeater(): for name, base_capture in item.base_captures.items(): captures[name] = MacroMatcherEmptyCapture(base_capture.depth + 1) case MacroMatcherUnion(): captures.update(_base_captures_from_matcher(item[0])) case _: assert_never(item) return captures