From 88d72569deda6268761156e5fd5610e47df13f31 Mon Sep 17 00:00:00 2001 From: Matthias Guenther Date: Fri, 14 Feb 2025 10:52:45 -0800 Subject: [PATCH] Make `generate_hlo_test_checks.py` backwards-compatible with Python 3.9. PiperOrigin-RevId: 726986626 --- xla/hlo/tools/generate_hlo_test_checks.py | 72 +++++++++++++---------- 1 file changed, 41 insertions(+), 31 deletions(-) diff --git a/xla/hlo/tools/generate_hlo_test_checks.py b/xla/hlo/tools/generate_hlo_test_checks.py index f849ffdd36464..7ea6c03729abf 100755 --- a/xla/hlo/tools/generate_hlo_test_checks.py +++ b/xla/hlo/tools/generate_hlo_test_checks.py @@ -88,6 +88,7 @@ OPT_ARGS...: {} --passes=foo,bar """ +from __future__ import annotations import argparse import collections from collections.abc import Callable, Iterator @@ -102,11 +103,11 @@ import subprocess import sys import tempfile -from typing import Generic, Self, TypeAlias, TypeVar, cast +from typing import Generic, Optional, TypeVar, Union, cast _T = TypeVar("_T") -ListOrTuple: TypeAlias = list[_T] | tuple[_T, ...] +ListOrTuple = Union[list[_T], tuple[_T, ...]] _SCRIPT_NAME: str = os.path.basename(__file__) _BANNER_COMMENT_LINE: str = ( @@ -120,6 +121,10 @@ r"^((?:[\w]+ )*)%([\w\-]+(?:\.[\w\-]+)*?(\.\d+)?)(?= [({])" ) +_DIRECTIVE_REGEX_MATCHER: re.Pattern[str] = re.compile( + r"^ *// *(CHECK(?:-(?:COUNT|DAG|EMPTY|LABEL|NEXT|NOT|SAME))?|COM|RUN):") +_DIRECTIVE_MAX_STRING_LENGTH: int = len("CHECK-LABEL") + class DirectiveComment(enum.Enum): """LLVM directive comments. @@ -140,11 +145,6 @@ class DirectiveComment(enum.Enum): COM = 8 RUN = 9 - _REGEX_MATCHER: re.Pattern[str] = enum.nonmember(re.compile( - r"^ *// *(CHECK(?:-(?:COUNT|DAG|EMPTY|LABEL|NEXT|NOT|SAME))?|COM|RUN):")) - - _MAX_CHECK_STRING_WIDTH: int = enum.nonmember(len("CHECK-LABEL")) - def __str__(self) -> str: # Note: `DirectiveComment.padding_width` makes use of the fact that # `len(str(self)) == len(self.name)` to avoid an unnecessary `str.replace()` @@ -161,21 +161,21 @@ def padding_width(self) -> int: # Note: This uses the fact that `len(str(self)) == len(self.name)` to avoid # an unnecessary `str.replace()` operation. If that assumption ever ceases # to hold in all cases, replace `len(self.name)` with `len(str(self))`. - return self._MAX_CHECK_STRING_WIDTH - len(self.name) if self.is_check else 0 + return _DIRECTIVE_MAX_STRING_LENGTH - len(self.name) if self.is_check else 0 @property def line_prefix(self) -> str: return f"// {self}: {' ' * self.padding_width}" @classmethod - def parse(cls, check_string: str) -> Self: + def parse(cls, check_string: str) -> DirectiveComment: """Parses a string representation of a DirectiveComment.""" return cls[check_string.replace("-", "_")] @classmethod - def extract_from_line(cls, line: str) -> Self | None: + def extract_from_line(cls, line: str) -> Optional[DirectiveComment]: """Returns the FileCheck/RUN directive, if any, used by a line of text.""" - match = cls._REGEX_MATCHER.match(line) + match = _DIRECTIVE_REGEX_MATCHER.match(line) return None if match is None else cls.parse(match.group(1)) def format_line(self, line_text: str) -> str: @@ -287,12 +287,14 @@ def __init__( self, input_stream: Iterator[_T], select_buffer: Callable[ - [_T], collections.deque[_T] | tuple[collections.deque[_T], ...] | None + [_T], + Union[collections.deque[_T], tuple[collections.deque[_T], ...], None], ], ): self._input_stream: Iterator[_T] = input_stream self._select_buffer: Callable[ - [_T], collections.deque[_T] | tuple[collections.deque[_T], ...] | None + [_T], + Union[collections.deque[_T], tuple[collections.deque[_T], ...], None], ] = select_buffer def next_in_buffer(self, target_buffer: collections.deque[_T]) -> _T: @@ -336,10 +338,14 @@ def next_in_buffer(self, target_buffer: collections.deque[_T]) -> _T: return item continue + T = TypeVar("T", bound=_T) + expected_type = Union[ + collections.deque[T], tuple[collections.deque[T], ...], None + ] raise TypeError( f"`{self._select_buffer}` returned a value of type " f"`{type(which_buffer).__name__}`; expected one of " - f"`{collections.deque[_T] | tuple[collections.deque[_T], ...] | None}" + f"`{expected_type}" f"`." ) @@ -364,7 +370,7 @@ class HloStreamSplitter: def __init__(self, input_stream: Iterator[str], - record_directives: set[DirectiveComment] | None = None): + record_directives: Optional[set[DirectiveComment]] = None): if record_directives is None: record_directives = set() @@ -395,7 +401,7 @@ def directive_history(self, directive: DirectiveComment) -> Iterator[str]: def _select_buffer( self, line: str - ) -> collections.deque[str] | tuple[collections.deque[str], ...] | None: + ) -> Union[collections.deque[str], tuple[collections.deque[str], ...], None]: directive = DirectiveComment.extract_from_line(line) if directive is None: @@ -517,7 +523,7 @@ def _replace_symbol_names_with_regex_captures( The transformed lines of the HLO test. """ for line in input_stream: - match: re.Match[str] | None = self._CHECK_LINE_REGEX.match(line) + match: Optional[re.Match[str]] = self._CHECK_LINE_REGEX.match(line) if match is None: if line == self._END_OF_FUNCTION_SCOPE_SENTINEL_VALUE: @@ -634,7 +640,7 @@ def __init__( self, optimizer_path: str, optimizer_args: ListOrTuple[str], - worker_count: int | None = None, + worker_count: Optional[int] = None, expand_to_input: str = _DEFAULT_INPUT_FILE_EXPANSION_TOKEN, ): """TestCheckWriter constructor. @@ -656,15 +662,15 @@ def __init__( self._optimizer_path: str = optimizer_path self._optimizer_args: list[str] = list(optimizer_args) - self._worker_count: int | None = worker_count + self._worker_count: Optional[int] = worker_count self._expand_to_input: str = expand_to_input # The worker pool, if applicable, is created in `__enter__` and destroyed in # `__exit__`. - self._worker_pool: multiprocessing.pool.Pool | None = None + self._worker_pool: Optional[multiprocessing.pool.Pool] = None self._context_manager_active: bool = False - def __enter__(self) -> Self: + def __enter__(self) -> TestCheckWriter: """Context manager setup. Initializes `self._worker_pool` if `self._worker_count` is either `None` @@ -733,9 +739,11 @@ def split_test_cases( def join_test_cases( self, - test_cases: Iterator[Iterator[str]] | Iterator[tuple[Iterator[str], ...]], + test_cases: Union[ + Iterator[Iterator[str]], Iterator[tuple[Iterator[str], ...]] + ], num_outputs: int = 1, - ) -> Iterator[str] | tuple[Iterator[str], ...]: + ) -> Union[Iterator[str], tuple[Iterator[str], ...]]: """Concatenates the output stream(s) from each test case in `test_cases`. Args: @@ -860,10 +868,12 @@ def _join_test_cases_n_ary( def for_each_test_case( self, test_file: Iterator[str], - transformation: (Callable[[Iterator[str]], Iterator[str]] | - Callable[[Iterator[str]], tuple[Iterator[str], ...]]), + transformation: Union[ + Callable[[Iterator[str]], Iterator[str]], + Callable[[Iterator[str]], tuple[Iterator[str], ...]], + ], num_outputs: int = 1, - ) -> Iterator[str] | tuple[Iterator[str], ...]: + ) -> Union[Iterator[str], tuple[Iterator[str], ...]]: """Applies `transformation` to each test case in `test_file`. Args: @@ -896,7 +906,7 @@ def for_each_test_case( test_cases = self.split_test_cases(test_file) transformed_test_cases = cast( - Iterator[Iterator[str]] | Iterator[tuple[Iterator[str], ...]], + Union[Iterator[Iterator[str]], Iterator[tuple[Iterator[str], ...]]], ( (transformation(test_case) for test_case in test_cases) if self._worker_pool is None @@ -970,8 +980,8 @@ def annotate_test_file(self, test_file: Iterator[str]) -> Iterator[str]: def transform_and_print_file( self, file_path: str, - transformation: Callable[[Iterator[str]], Iterator[str]] | None = None, - output_stream: io.TextIOBase = sys.stdout, + transformation: Optional[Callable[[Iterator[str]], Iterator[str]]] = None, + output_stream: io.TextIOBase = cast(io.TextIOBase, sys.stdout), ) -> None: """Reads from `file_path`, applies a transformation, and prints to `stdout`. @@ -999,7 +1009,7 @@ def transform_and_print_file( def transform_and_overwrite_file( self, file_path: str, - transformation: Callable[[Iterator[str]], Iterator[str]] | None = None, + transformation: Optional[Callable[[Iterator[str]], Iterator[str]]] = None, ) -> None: """Transforms the contents of `file_path`, overwriting the file. @@ -1026,7 +1036,7 @@ def transform_and_overwrite_file( def parse_args( - string_args: ListOrTuple[str] | None = None, + string_args: Optional[ListOrTuple[str]] = None, ) -> argparse.Namespace: """Parses the command-line arguments passed into this script.""" if string_args is None: