From 44a26b41b8cea35537b171bc115d39ffca8ecf79 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Thu, 16 Jan 2025 11:23:17 -0800 Subject: [PATCH 1/7] Target py310 and modernize codebase with ruff --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 60fe630b1378b..7279efcb0a534 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,7 +7,7 @@ reportMissingImports = false [tool.ruff] # NOTE: Do not create an exclude list. Edit .lintrunner.toml instead -target-version = "py38" +target-version = "py310" line-length = 120 [tool.ruff.lint] From ffda1c2c86504fc3bc61ed9f0d1a45ad9c6742c0 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Thu, 16 Jan 2025 11:24:36 -0800 Subject: [PATCH 2/7] UP038 --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index 7279efcb0a534..c30201e1b2745 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -55,6 +55,7 @@ ignore = [ "SIM108", # We don't encourage ternary operators "SIM114", # Don't combine if branches for debugability "SIM116", # Don't use dict lookup to replace if-else + "UP038", # Using X | Y in isinstance checks is a little aggresive ] ignore-init-module-imports = true From 8c372c57427e826664b33e5a2e5671c3ac9b1e99 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Thu, 16 Jan 2025 11:27:42 -0800 Subject: [PATCH 3/7] Format --- .../tools/MauiModelTester/create_test_data.py | 9 +- onnxruntime/python/backend/backend_rep.py | 2 - .../onnxruntime_inference_collection.py | 5 +- .../create_custom_op_wrapper.py | 7 +- .../kernels/kernel_explorer.py | 2 +- onnxruntime/python/tools/offline_tuning.py | 10 +- .../tools/quantization/base_quantizer.py | 4 +- .../python/tools/quantization/calibrate.py | 44 +++---- .../quantization/matmul_bnb4_quantizer.py | 7 +- .../tools/quantization/qdq_loss_debug.py | 50 ++++---- .../python/tools/quantization/quantize.py | 3 +- .../tools/quantization/shape_inference.py | 7 +- .../python/tools/symbolic_shape_infer.py | 29 ++--- .../python/tools/tensorrt/perf/benchmark.py | 2 +- .../tools/tensorrt/perf/build/build_image.py | 5 +- .../perf/parse_mem_concurrency_test.py | 6 +- .../tools/transformers/benchmark_helper.py | 8 +- .../tools/transformers/bert_perf_test.py | 5 +- .../tools/transformers/bert_test_data.py | 25 ++-- .../tools/transformers/convert_generation.py | 22 ++-- .../transformers/convert_to_packing_mode.py | 19 ++- .../python/tools/transformers/float16.py | 3 +- .../tools/transformers/fusion_attention.py | 41 ++++--- .../transformers/fusion_attention_clip.py | 3 +- .../transformers/fusion_attention_sam2.py | 5 +- .../transformers/fusion_attention_unet.py | 7 +- .../transformers/fusion_attention_vae.py | 5 +- .../python/tools/transformers/fusion_base.py | 21 ++-- .../tools/transformers/fusion_bias_add.py | 3 +- .../transformers/fusion_biassplitgelu.py | 3 +- .../tools/transformers/fusion_embedlayer.py | 11 +- .../tools/transformers/fusion_fastgelu.py | 11 +- .../python/tools/transformers/fusion_gelu.py | 9 +- .../tools/transformers/fusion_gemmfastgelu.py | 9 +- .../tools/transformers/fusion_group_norm.py | 3 +- .../tools/transformers/fusion_layernorm.py | 9 +- .../tools/transformers/fusion_mha_mmdit.py | 13 +-- .../tools/transformers/fusion_nhwc_conv.py | 3 +- .../transformers/fusion_qordered_attention.py | 3 +- .../transformers/fusion_qordered_gelu.py | 3 +- .../transformers/fusion_qordered_layernorm.py | 3 +- .../transformers/fusion_qordered_matmul.py | 3 +- .../transformers/fusion_rotary_attention.py | 7 +- .../python/tools/transformers/fusion_shape.py | 9 +- .../fusion_simplified_layernorm.py | 3 +- .../transformers/fusion_skip_group_norm.py | 3 +- .../tools/transformers/fusion_transpose.py | 11 +- .../python/tools/transformers/fusion_utils.py | 9 +- .../tools/transformers/io_binding_helper.py | 21 ++-- .../transformers/large_model_exporter.py | 7 +- .../python/tools/transformers/machine_info.py | 17 ++- .../python/tools/transformers/metrics.py | 51 ++++---- .../models/bart/utils/export_helper.py | 13 +-- .../bart/utils/export_summarization_edinit.py | 6 +- .../export_summarization_enc_dec_past.py | 4 +- .../transformers/models/bert/eval_squad.py | 8 +- .../transformers/models/gpt2/gpt2_helper.py | 17 ++- .../models/llama/convert_to_onnx.py | 8 +- .../models/longformer/benchmark_longformer.py | 20 ++-- .../models/longformer/longformer_helper.py | 13 +-- .../models/sam2/benchmark_sam2.py | 6 +- .../transformers/models/sam2/image_encoder.py | 2 +- .../transformers/models/sam2/sam2_demo.py | 7 +- .../models/sam2/sam2_image_onnx_predictor.py | 17 ++- .../transformers/models/sam2/sam2_utils.py | 6 +- .../models/stable_diffusion/demo_utils.py | 8 +- .../stable_diffusion/diffusion_models.py | 27 +++-- .../stable_diffusion/diffusion_schedulers.py | 11 +- .../models/stable_diffusion/engine_builder.py | 3 +- .../engine_builder_ort_cuda.py | 11 +- .../stable_diffusion/optimize_pipeline.py | 15 ++- .../pipeline_stable_diffusion.py | 20 ++-- .../stable_diffusion/test/check_image.py | 5 +- .../transformers/models/t5/past_helper.py | 11 +- .../transformers/models/t5/t5_decoder.py | 15 ++- .../transformers/models/t5/t5_encoder.py | 5 +- .../models/t5/t5_encoder_decoder_init.py | 9 +- .../tools/transformers/models/t5/t5_helper.py | 9 +- .../models/whisper/whisper_chain.py | 2 +- .../models/whisper/whisper_decoder.py | 9 +- .../models/whisper/whisper_encoder.py | 3 +- .../whisper/whisper_encoder_decoder_init.py | 5 +- .../models/whisper/whisper_helper.py | 9 +- .../python/tools/transformers/onnx_model.py | 25 ++-- .../tools/transformers/onnx_model_bart.py | 3 +- .../tools/transformers/onnx_model_bert.py | 5 +- .../transformers/onnx_model_conformer.py | 3 +- .../tools/transformers/onnx_model_mmdit.py | 5 +- .../tools/transformers/onnx_model_phi.py | 27 +++-- .../tools/transformers/onnx_model_sam2.py | 7 +- .../tools/transformers/onnx_model_t5.py | 7 +- .../tools/transformers/onnx_model_tnlr.py | 3 +- .../tools/transformers/onnx_model_unet.py | 7 +- .../tools/transformers/onnx_model_vae.py | 3 +- .../python/tools/transformers/onnx_utils.py | 2 +- .../python/tools/transformers/optimizer.py | 23 ++-- .../tools/transformers/shape_infer_helper.py | 5 +- .../tools/transformers/shape_optimizer.py | 5 +- .../multihead_attention_op_test_data_gen.py | 13 +-- onnxruntime/test/providers/cpu/rnn/LSTM.py | 8 +- .../cpu/tensor/affine_grid_test_gen.py | 4 +- .../contrib_ops/onnx_contrib_ops_helper.py | 4 +- .../test/python/onnx_backend_test_series.py | 5 +- .../python/onnxruntime_test_distributed.py | 44 +++---- .../onnxruntime_test_python_cudagraph.py | 5 +- .../onnxruntime_test_python_dmlgraph.py | 5 +- ...time_test_python_nested_control_flow_op.py | 6 +- ...untime_test_python_symbolic_shape_infer.py | 2 +- .../test/python/onnxruntime_test_scatternd.py | 2 +- .../python/quantization/test_calibration.py | 8 +- .../quantization/test_op_matmul_4bits.py | 13 +-- .../quantization/test_op_matmul_bnb4.py | 7 +- .../test/python/quantization/test_op_pad.py | 6 +- .../quantization/test_qdq_loss_debug.py | 7 +- .../test_tensor_quant_overrides_option.py | 2 +- .../python/test_pytorch_export_contrib_ops.py | 5 +- .../test/python/transformers/benchmark_gqa.py | 12 +- .../transformers/benchmark_gqa_windows.py | 3 +- .../test/python/transformers/benchmark_mha.py | 14 +-- .../transformers/bert_model_generator.py | 3 +- .../transformers/conformer_model_generator.py | 3 +- .../transformers/gpt2_model_generator.py | 1 - .../test/python/transformers/rotary_flash.py | 36 +++--- .../transformers/test_gemmfastgelu_fusion.py | 3 +- .../python/transformers/test_group_norm.py | 13 +-- .../test/python/transformers/test_mha.py | 19 ++- .../test_parity_decoder_attention.py | 19 ++- .../test_rotary_embedding_fusion.py | 7 +- .../transformers/test_rotary_mha_fusion.py | 13 +-- .../test_simplified_layernorm_fusion.py | 9 +- .../test_skip_layer_norm_fusion.py | 9 +- .../transformers/test_sparse_attention.py | 9 +- .../transformers/whisper_model_generator.py | 3 +- onnxruntime/test/testdata/CNTK/gen.py | 2 +- .../testdata/sparse_initializer_as_output.py | 2 +- .../test/testdata/sparse_to_dense_matmul.py | 2 +- .../orttraining/python/training/artifacts.py | 19 ++- .../gradient_graph/_gradient_graph_tools.py | 7 +- .../python/training/onnxblock/_graph_utils.py | 3 +- .../onnxblock/_training_graph_utils.py | 25 ++-- .../python/training/onnxblock/blocks.py | 20 ++-- .../training/onnxblock/checkpoint_utils.py | 7 +- .../python/training/onnxblock/loss/loss.py | 5 +- .../python/training/onnxblock/onnxblock.py | 5 +- .../python/training/onnxblock/optim/optim.py | 19 ++- .../python/training/optim/_ds_modifier.py | 2 +- .../python/training/ort_triton/_cache.py | 3 +- .../python/training/ort_triton/_codegen.py | 8 +- .../python/training/ort_triton/_common.py | 56 ++++----- .../python/training/ort_triton/_decompose.py | 4 +- .../python/training/ort_triton/_ir.py | 110 +++++++++--------- .../python/training/ort_triton/_lowering.py | 64 +++++----- .../training/ort_triton/_sorted_graph.py | 25 ++-- .../training/ort_triton/_sympy_utils.py | 8 +- .../python/training/ort_triton/_utils.py | 8 +- .../training/ort_triton/kernel/_flash_attn.py | 13 +-- .../python/training/ort_triton/kernel/_mm.py | 7 +- .../training/ort_triton/triton_op_executor.py | 9 +- .../_custom_autograd_function_exporter.py | 2 +- .../ortmodule/_custom_op_symbolic_registry.py | 2 +- .../training/ortmodule/_execution_agent.py | 3 +- .../python/training/ortmodule/_fallback.py | 3 +- .../ortmodule/_graph_execution_manager.py | 9 +- .../_graph_execution_manager_factory.py | 3 +- .../ortmodule/_graph_transition_manager.py | 2 +- .../training/ortmodule/_inference_manager.py | 3 +- .../python/training/ortmodule/_io.py | 80 ++++++------- .../python/training/ortmodule/_logger.py | 12 +- .../python/training/ortmodule/_onnx_models.py | 3 +- .../training/ortmodule/_runtime_inspector.py | 23 ++-- .../ortmodule/_torch_module_interface.py | 13 ++- .../training/ortmodule/_torch_module_ort.py | 13 ++- .../ortmodule/_torch_module_pytorch.py | 13 ++- .../training/ortmodule/_training_manager.py | 5 +- .../python/training/ortmodule/_utils.py | 16 +-- .../ortmodule/_zero_stage3_compatibility.py | 37 +++--- .../ortmodule/graph_optimizer_registry.py | 2 +- .../ortmodule/graph_optimizers/_aten_attn.py | 10 +- .../ortmodule/graph_optimizers/utils.py | 17 +-- .../python/training/ortmodule/ortmodule.py | 18 +-- .../python/training/utils/data/sampler.py | 8 +- .../utils/hooks/_statistics_subscriber.py | 13 +-- .../training/utils/hooks/_subscriber_base.py | 15 ++- .../utils/hooks/_subscriber_manager.py | 19 ++- .../utils/hooks/_zero_offload_subscriber.py | 33 +++--- .../python/training/utils/ptable.py | 10 +- .../python/training/utils/torch_io_helper.py | 52 ++++----- .../python/training/utils/torch_type_map.py | 4 +- .../orttraining/test/python/_test_helpers.py | 4 +- .../test/python/orttraining_test_dort.py | 6 +- ...aining_test_experimental_gradient_graph.py | 2 +- .../test/python/orttraining_test_gru.py | 4 +- ...orttraining_test_hierarchical_ortmodule.py | 2 +- .../test/python/orttraining_test_lort.py | 2 +- .../test/python/orttraining_test_lstm.py | 4 +- .../orttraining_test_ort_apis_onnxblock.py | 4 +- .../orttraining_test_ort_pipeline_module.py | 11 +- .../python/orttraining_test_ortmodule_api.py | 24 ++-- .../orttraining_test_ortmodule_autograd.py | 19 ++- .../orttraining_test_ortmodule_onnx_ops.py | 2 +- .../orttraining_test_ortmodule_triton.py | 4 +- .../test/python/orttraining_test_ortvalue.py | 6 +- .../test/python/orttraining_test_sampler.py | 10 +- .../test/python/orttraining_test_utilities.py | 9 +- orttraining/tools/ci_test/compare_results.py | 2 +- .../tools/scripts/nv_run_pretraining.py | 2 +- tools/ci_build/build.py | 2 +- .../github/apple/package_assembly_utils.py | 7 +- tools/ci_build/op_registration_utils.py | 13 +-- tools/ci_build/op_registration_validator.py | 14 +-- tools/ci_build/reduce_op_kernels.py | 22 ++-- ...ptimizer_opset_version_updates_required.py | 7 +- tools/python/gen_contrib_doc.py | 2 +- tools/python/onnx2tfevents.py | 7 +- tools/python/ort_test_dir_utils.py | 2 +- tools/python/run_CIs_for_branch.py | 3 +- tools/python/run_CIs_for_external_pr.py | 3 +- tools/python/run_adb.py | 3 +- tools/python/sparsify_initializers.py | 3 +- tools/python/util/android/android.py | 4 +- tools/python/util/file_utils.py | 6 +- tools/python/util/onnx_model_utils.py | 5 +- .../operator_type_usage_processors.py | 27 ++--- 223 files changed, 1161 insertions(+), 1298 deletions(-) diff --git a/csharp/tools/MauiModelTester/create_test_data.py b/csharp/tools/MauiModelTester/create_test_data.py index 6c57c71f94216..d73fd950a7bc0 100644 --- a/csharp/tools/MauiModelTester/create_test_data.py +++ b/csharp/tools/MauiModelTester/create_test_data.py @@ -2,7 +2,6 @@ import shutil import sys from pathlib import Path -from typing import Dict, List, Optional import numpy as np @@ -84,7 +83,7 @@ def parse_args(): return args -def create_existing_data_map(pb_files: List[Path]): +def create_existing_data_map(pb_files: list[Path]): import onnx_test_data_utils as data_utils data_map = {} @@ -98,9 +97,9 @@ def create_existing_data_map(pb_files: List[Path]): def add_model_and_test_data_to_app( model_path: Path, - symbolic_dims: Optional[Dict[str, int]] = None, - input_map: Optional[Dict[str, np.ndarray]] = None, - output_map: Optional[Dict[str, np.ndarray]] = None, + symbolic_dims: dict[str, int] | None = None, + input_map: dict[str, np.ndarray] | None = None, + output_map: dict[str, np.ndarray] | None = None, ): import ort_test_dir_utils as utils diff --git a/onnxruntime/python/backend/backend_rep.py b/onnxruntime/python/backend/backend_rep.py index af785b71c5f55..a30569d004d34 100644 --- a/onnxruntime/python/backend/backend_rep.py +++ b/onnxruntime/python/backend/backend_rep.py @@ -6,8 +6,6 @@ Implements ONNX's backend API. """ -from typing import Any, Tuple # noqa: F401 - from onnx.backend.base import BackendRep from onnxruntime import RunOptions diff --git a/onnxruntime/python/onnxruntime_inference_collection.py b/onnxruntime/python/onnxruntime_inference_collection.py index c12efc7fdfc9b..a3741abc48077 100644 --- a/onnxruntime/python/onnxruntime_inference_collection.py +++ b/onnxruntime/python/onnxruntime_inference_collection.py @@ -9,7 +9,8 @@ import os import typing import warnings -from typing import Any, Sequence +from collections.abc import Sequence +from typing import Any from onnxruntime.capi import _pybind_state as C @@ -143,7 +144,7 @@ def set_provider_options(name, options): if not all([isinstance(options_for_provider, dict) for options_for_provider in provider_options]): raise ValueError("'provider_options' values must be dicts.") - for name, options in zip(providers, provider_options): + for name, options in zip(providers, provider_options, strict=False): set_provider_options(name, options) else: diff --git a/onnxruntime/python/tools/custom_op_wrapper/create_custom_op_wrapper.py b/onnxruntime/python/tools/custom_op_wrapper/create_custom_op_wrapper.py index e0967ef5545db..76238b982fd96 100644 --- a/onnxruntime/python/tools/custom_op_wrapper/create_custom_op_wrapper.py +++ b/onnxruntime/python/tools/custom_op_wrapper/create_custom_op_wrapper.py @@ -22,7 +22,6 @@ import os import sys from dataclasses import dataclass -from typing import List, Optional, Union import onnx from onnx import TensorProto, helper @@ -65,7 +64,7 @@ class IOInfo: index: int name: str elem_type: TensorProto.DataType - shape: Optional[List[Union[int, str]]] + shape: list[int | str] | None def str_is_int(string: str) -> bool: @@ -76,7 +75,7 @@ def str_is_int(string: str) -> bool: return False -def parse_shape(shape_str: str) -> Optional[List[Union[int, str]]]: +def parse_shape(shape_str: str) -> list[int | str] | None: try: shape = [int(s) if str_is_int(s) else s for s in shape_str.split(",")] except ValueError: @@ -204,7 +203,7 @@ def parse_arguments() -> argparse.Namespace: return parser.parse_args() -def get_attributes(attr_data_info: List[List[str]]): +def get_attributes(attr_data_info: list[list[str]]): if not attr_data_info: return {} diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/kernel_explorer.py b/onnxruntime/python/tools/kernel_explorer/kernels/kernel_explorer.py index 66e1a8052ce84..363eb3865e699 100644 --- a/onnxruntime/python/tools/kernel_explorer/kernels/kernel_explorer.py +++ b/onnxruntime/python/tools/kernel_explorer/kernels/kernel_explorer.py @@ -13,11 +13,11 @@ import sys from abc import abstractmethod from argparse import Action, ArgumentParser +from collections.abc import Callable from contextlib import contextmanager from dataclasses import dataclass from fnmatch import fnmatch from functools import wraps -from typing import Callable build_dir = os.environ.get("KERNEL_EXPLORER_BUILD_DIR", None) if build_dir is None: diff --git a/onnxruntime/python/tools/offline_tuning.py b/onnxruntime/python/tools/offline_tuning.py index c032685b70f7c..c55b515814a28 100644 --- a/onnxruntime/python/tools/offline_tuning.py +++ b/onnxruntime/python/tools/offline_tuning.py @@ -7,11 +7,11 @@ import sys from collections import OrderedDict from pprint import pprint -from typing import Any, Dict, List +from typing import Any import onnx -TuningResults = Dict[str, Any] +TuningResults = dict[str, Any] _TUNING_RESULTS_KEY = "tuning_results" @@ -32,7 +32,7 @@ def extract(model: onnx.ModelProto): return json.loads(tuning_results_prop.value) -def embed(model: onnx.ModelProto, tuning_results: List[TuningResults], overwrite=False): +def embed(model: onnx.ModelProto, tuning_results: list[TuningResults], overwrite=False): idx = _find_tuning_results_in_props(model.metadata_props) assert overwrite or idx <= 0, "the supplied onnx file already have tuning results embedded!" @@ -47,7 +47,7 @@ def embed(model: onnx.ModelProto, tuning_results: List[TuningResults], overwrite class Merger: class EpAndValidators: - def __init__(self, ep: str, validators: Dict[str, str]): + def __init__(self, ep: str, validators: dict[str, str]): self.ep = ep self.validators = copy.deepcopy(validators) self.key = (ep, tuple(sorted(validators.items()))) @@ -61,7 +61,7 @@ def __eq__(self, other): def __init__(self): self.ev_to_results = OrderedDict() - def merge(self, tuning_results: List[TuningResults]): + def merge(self, tuning_results: list[TuningResults]): for trs in tuning_results: self._merge_one(trs) diff --git a/onnxruntime/python/tools/quantization/base_quantizer.py b/onnxruntime/python/tools/quantization/base_quantizer.py index 0cd186bffdea0..ac11607e02710 100644 --- a/onnxruntime/python/tools/quantization/base_quantizer.py +++ b/onnxruntime/python/tools/quantization/base_quantizer.py @@ -4,7 +4,7 @@ # license information. # -------------------------------------------------------------------------- import logging -from typing import Any, Dict +from typing import Any import numpy as np import onnx @@ -36,7 +36,7 @@ class QuantizationParams: - def __init__(self, **data: Dict[str, Any]): + def __init__(self, **data: dict[str, Any]): self.data = {} for k, v in data.items(): if not isinstance(k, str): diff --git a/onnxruntime/python/tools/quantization/calibrate.py b/onnxruntime/python/tools/quantization/calibrate.py index 7855f260a551a..f3bb533ac89e8 100644 --- a/onnxruntime/python/tools/quantization/calibrate.py +++ b/onnxruntime/python/tools/quantization/calibrate.py @@ -9,9 +9,9 @@ import itertools import os import uuid +from collections.abc import Sequence from enum import Enum from pathlib import Path -from typing import Dict, Optional, Sequence, Tuple, Union import numpy as np import onnx @@ -39,7 +39,7 @@ def rel_entr(pk: np.ndarray, qk: np.ndarray) -> np.ndarray: def entropy( pk: np.ndarray, qk: np.ndarray, - base: Optional[float] = None, + base: float | None = None, axis: int = 0, ) -> np.ndarray: """ @@ -100,7 +100,7 @@ def to_dict(self): class TensorsData: - def __init__(self, calibration_method, data: Dict[str, Union[TensorData, Tuple]]): + def __init__(self, calibration_method, data: dict[str, TensorData | tuple]): self.calibration_method = calibration_method self.data = {} for k, v in data.items(): @@ -187,8 +187,8 @@ def set_range(self, start_index: int, end_index: int): class CalibraterBase: def __init__( self, - model_path: Union[str, Path], - op_types_to_calibrate: Optional[Sequence[str]] = None, + model_path: str | Path, + op_types_to_calibrate: Sequence[str] | None = None, augmented_model_path="augmented_model.onnx", symmetric=False, use_external_data_format=False, @@ -297,8 +297,8 @@ def compute_data(self) -> TensorsData: class MinMaxCalibrater(CalibraterBase): def __init__( self, - model_path: Union[str, Path], - op_types_to_calibrate: Optional[Sequence[str]] = None, + model_path: str | Path, + op_types_to_calibrate: Sequence[str] | None = None, augmented_model_path="augmented_model.onnx", symmetric=False, use_external_data_format=False, @@ -476,7 +476,8 @@ def compute_data(self) -> TensorsData: output_names = [self.infer_session.get_outputs()[i].name for i in range(len(self.intermediate_outputs[0]))] output_dicts_list = [ - dict(zip(output_names, intermediate_output)) for intermediate_output in self.intermediate_outputs + dict(zip(output_names, intermediate_output, strict=False)) + for intermediate_output in self.intermediate_outputs ] merged_output_dict = {} @@ -507,7 +508,9 @@ def compute_data(self) -> TensorsData: else: pairs.append(tuple([min_value_array, max_value_array])) - new_calibrate_tensors_range = TensorsData(CalibrationMethod.MinMax, dict(zip(calibrate_tensor_names, pairs))) + new_calibrate_tensors_range = TensorsData( + CalibrationMethod.MinMax, dict(zip(calibrate_tensor_names, pairs, strict=False)) + ) if self.calibrate_tensors_range: self.calibrate_tensors_range = self.merge_range(self.calibrate_tensors_range, new_calibrate_tensors_range) else: @@ -519,8 +522,8 @@ def compute_data(self) -> TensorsData: class HistogramCalibrater(CalibraterBase): def __init__( self, - model_path: Union[str, Path], - op_types_to_calibrate: Optional[Sequence[str]] = None, + model_path: str | Path, + op_types_to_calibrate: Sequence[str] | None = None, augmented_model_path="augmented_model.onnx", use_external_data_format=False, method="percentile", @@ -608,7 +611,8 @@ def collect_data(self, data_reader: CalibrationDataReader): raise ValueError("No data is collected.") output_dicts_list = [ - dict(zip(output_names, intermediate_output)) for intermediate_output in self.intermediate_outputs + dict(zip(output_names, intermediate_output, strict=False)) + for intermediate_output in self.intermediate_outputs ] merged_dict = {} @@ -653,8 +657,8 @@ def compute_data(self) -> TensorsData: class EntropyCalibrater(HistogramCalibrater): def __init__( self, - model_path: Union[str, Path], - op_types_to_calibrate: Optional[Sequence[str]] = None, + model_path: str | Path, + op_types_to_calibrate: Sequence[str] | None = None, augmented_model_path="augmented_model.onnx", use_external_data_format=False, method="entropy", @@ -687,8 +691,8 @@ def __init__( class PercentileCalibrater(HistogramCalibrater): def __init__( self, - model_path: Union[str, Path], - op_types_to_calibrate: Optional[Sequence[str]] = None, + model_path: str | Path, + op_types_to_calibrate: Sequence[str] | None = None, augmented_model_path="augmented_model.onnx", use_external_data_format=False, method="percentile", @@ -721,8 +725,8 @@ def __init__( class DistributionCalibrater(HistogramCalibrater): def __init__( self, - model_path: Union[str, Path], - op_types_to_calibrate: Optional[Sequence[str]] = None, + model_path: str | Path, + op_types_to_calibrate: Sequence[str] | None = None, augmented_model_path="augmented_model.onnx", use_external_data_format=False, method="distribution", @@ -1168,8 +1172,8 @@ def get_entropy_threshold(self, histogram, num_quantized_bins): def create_calibrator( - model: Union[str, Path], - op_types_to_calibrate: Optional[Sequence[str]] = None, + model: str | Path, + op_types_to_calibrate: Sequence[str] | None = None, augmented_model_path="augmented_model.onnx", calibrate_method=CalibrationMethod.MinMax, use_external_data_format=False, diff --git a/onnxruntime/python/tools/quantization/matmul_bnb4_quantizer.py b/onnxruntime/python/tools/quantization/matmul_bnb4_quantizer.py index 2bf47fe1680e9..2e8ee11e2f864 100644 --- a/onnxruntime/python/tools/quantization/matmul_bnb4_quantizer.py +++ b/onnxruntime/python/tools/quantization/matmul_bnb4_quantizer.py @@ -7,7 +7,6 @@ import argparse import logging import os -from typing import List, Tuple import numpy as np import numpy.typing as npt @@ -44,7 +43,7 @@ def __init__(self, model: ModelProto, quant_type: int, block_size: int, nodes_to self.nodes_to_exclude = set(nodes_to_exclude) @staticmethod - def __get_initializer(name, graph_path: List[GraphProto]) -> Tuple[TensorProto, GraphProto]: + def __get_initializer(name, graph_path: list[GraphProto]) -> tuple[TensorProto, GraphProto]: for gid in range(len(graph_path) - 1, -1, -1): graph = graph_path[gid] for tensor in graph.initializer: @@ -74,7 +73,7 @@ def bnb4_block_quant(self, fpweight: npt.ArrayLike) -> np.ndarray: return (packed, absmax) - def _bnb4_matmul_node_weight(self, node: NodeProto, graph_stack: List[GraphProto]) -> NodeProto: + def _bnb4_matmul_node_weight(self, node: NodeProto, graph_stack: list[GraphProto]) -> NodeProto: """If the node is MatMul with fp32 const weight, quantize the weight with int4, and return the new node""" if node.op_type != "MatMul": @@ -129,7 +128,7 @@ def _bnb4_matmul_node_weight(self, node: NodeProto, graph_stack: List[GraphProto return matmul_bnb4_node - def _process_subgraph(self, graph_stack: List[GraphProto]): + def _process_subgraph(self, graph_stack: list[GraphProto]): new_nodes = [] graph = graph_stack[-1] diff --git a/onnxruntime/python/tools/quantization/qdq_loss_debug.py b/onnxruntime/python/tools/quantization/qdq_loss_debug.py index f9ed844febe46..9b545f2e94a2f 100644 --- a/onnxruntime/python/tools/quantization/qdq_loss_debug.py +++ b/onnxruntime/python/tools/quantization/qdq_loss_debug.py @@ -37,8 +37,8 @@ def get_next(self): import logging import math import time +from collections.abc import Callable, Sequence from pathlib import Path -from typing import Callable, Dict, List, Optional, Sequence, Union import numpy import onnx @@ -62,9 +62,9 @@ def get_next(self): def modify_model_output_intermediate_tensors( - input_model_path: Union[str, Path], - output_model_path: Union[str, Path], - op_types_for_saving: Optional[Sequence[str]] = None, + input_model_path: str | Path, + output_model_path: str | Path, + op_types_for_saving: Sequence[str] | None = None, save_as_external_data: bool = False, ) -> None: """Augment a given ONNX model to save node input/output tensors. @@ -116,8 +116,8 @@ def collect_activations( augmented_model: str, input_reader: CalibrationDataReader, session_options=None, - execution_providers: Optional[Sequence[str]] = None, -) -> Dict[str, List[numpy.ndarray]]: + execution_providers: Sequence[str] | None = None, +) -> dict[str, list[numpy.ndarray]]: """Run augmented model and collect activations tensors. Args: @@ -154,7 +154,7 @@ def collect_activations( output_dict = {} output_info = inference_session.get_outputs() for batch in intermediate_outputs: - for output, output_data in zip(output_info, batch): + for output, output_data in zip(output_info, batch, strict=False): if output.name.endswith(_TENSOR_SAVE_POSTFIX): output_name = output.name[:-_TENSOR_SAVE_POSTFIX_LEN] output_dict.setdefault(output_name, []).append(output_data) @@ -166,10 +166,10 @@ def collect_activations( def _add_pre_post_qdq_pair( - qdq_cmp: Dict[str, Dict[str, Sequence[numpy.ndarray]]], + qdq_cmp: dict[str, dict[str, Sequence[numpy.ndarray]]], activation_name: str, - pre_qdq_tensors: Optional[Sequence[numpy.ndarray]], - post_qdq_tensors: Optional[Sequence[numpy.ndarray]], + pre_qdq_tensors: Sequence[numpy.ndarray] | None, + post_qdq_tensors: Sequence[numpy.ndarray] | None, ) -> None: if post_qdq_tensors is not None and pre_qdq_tensors is not None: qdq_cmp[activation_name] = {} @@ -178,9 +178,9 @@ def _add_pre_post_qdq_pair( def create_activation_matching( - qdq_activations: Dict[str, Sequence[numpy.ndarray]], - float_activations: Optional[Dict[str, Sequence[numpy.ndarray]]] = None, -) -> Dict[str, Dict[str, Sequence[numpy.ndarray]]]: + qdq_activations: dict[str, Sequence[numpy.ndarray]], + float_activations: dict[str, Sequence[numpy.ndarray]] | None = None, +) -> dict[str, dict[str, Sequence[numpy.ndarray]]]: """Comparing activation values to help debugging accuracy loss due to quantization. This functions takes saved activations from the QDQ model and (optionally) the @@ -210,7 +210,7 @@ def create_activation_matching( ``` """ - qdq_cmp: Dict[str, Dict[str, Sequence[numpy.ndarray]]] = {} + qdq_cmp: dict[str, dict[str, Sequence[numpy.ndarray]]] = {} for tensor_name, tensors in qdq_activations.items(): if tensor_name.endswith(QUANT_INPUT_SUFFIX): pre_name = tensor_name[: -len(QUANT_INPUT_SUFFIX)] @@ -241,7 +241,7 @@ def create_activation_matching( def _run_dequantize_linear( weight_tensor: numpy.ndarray, weight_scale: numpy.ndarray, weight_zp: numpy.ndarray, channel_axis: int -) -> Optional[numpy.ndarray]: +) -> numpy.ndarray | None: assert weight_scale.shape == weight_zp.shape if weight_zp.size == 1: return (weight_tensor - weight_zp) * weight_scale @@ -267,7 +267,7 @@ def _run_dequantize_linear( return dequantized_weights -def create_weight_matching(float_model_path: str, qdq_model_path: str) -> Dict[str, Dict[str, numpy.ndarray]]: +def create_weight_matching(float_model_path: str, qdq_model_path: str) -> dict[str, dict[str, numpy.ndarray]]: """Comparing weight values to help debugging accuracy loss due to quantization. This functions takes the float model and the qdq model, and provides a data structure for comparing @@ -288,7 +288,7 @@ def create_weight_matching(float_model_path: str, qdq_model_path: str) -> Dict[s float_onnx_model = ONNXModel(load_model_with_shape_infer(Path(float_model_path))) qdq_onnx_model = ONNXModel(load_model_with_shape_infer(Path(qdq_model_path))) - matched_weights: Dict[str, Dict[str, numpy.ndarray]] = {} + matched_weights: dict[str, dict[str, numpy.ndarray]] = {} initializers = qdq_onnx_model.initializer() for node in qdq_onnx_model.nodes(): if node.op_type != DEQUANT_OP_NAME: @@ -339,7 +339,7 @@ def create_weight_matching(float_model_path: str, qdq_model_path: str) -> Dict[s def compute_signal_to_quantization_noice_ratio( - x: Union[Sequence[numpy.ndarray], numpy.ndarray], y: Union[Sequence[numpy.ndarray], numpy.ndarray] + x: Sequence[numpy.ndarray] | numpy.ndarray, y: Sequence[numpy.ndarray] | numpy.ndarray ) -> float: if isinstance(x, numpy.ndarray): xlist = [x] @@ -363,24 +363,24 @@ def compute_signal_to_quantization_noice_ratio( def compute_weight_error( - weights_match: Dict[str, Dict[str, numpy.ndarray]], + weights_match: dict[str, dict[str, numpy.ndarray]], err_func: Callable[[numpy.ndarray, numpy.ndarray], float] = compute_signal_to_quantization_noice_ratio, -) -> Dict[str, float]: - result: Dict[str, float] = {} +) -> dict[str, float]: + result: dict[str, float] = {} for weight_name, weight_match in weights_match.items(): result[weight_name] = err_func(weight_match["float"], weight_match["dequantized"]) return result def compute_activation_error( - activations_match: Dict[str, Dict[str, Sequence[numpy.ndarray]]], + activations_match: dict[str, dict[str, Sequence[numpy.ndarray]]], err_func: Callable[ [Sequence[numpy.ndarray], Sequence[numpy.ndarray]], float ] = compute_signal_to_quantization_noice_ratio, -) -> Dict[str, Dict[str, float]]: - result: Dict[str, Dict[str, float]] = {} +) -> dict[str, dict[str, float]]: + result: dict[str, dict[str, float]] = {} for name, match in activations_match.items(): - err_result: Dict[str, float] = {} + err_result: dict[str, float] = {} err_result["qdq_err"] = err_func(match["pre_qdq"], match["post_qdq"]) float_activation = match["float"] if float_activation: diff --git a/onnxruntime/python/tools/quantization/quantize.py b/onnxruntime/python/tools/quantization/quantize.py index 4ffd8b9872982..27221f9445c30 100644 --- a/onnxruntime/python/tools/quantization/quantize.py +++ b/onnxruntime/python/tools/quantization/quantize.py @@ -8,8 +8,9 @@ import copy import logging import tempfile +from collections.abc import Callable from pathlib import Path -from typing import Any, Callable +from typing import Any import onnx diff --git a/onnxruntime/python/tools/quantization/shape_inference.py b/onnxruntime/python/tools/quantization/shape_inference.py index c07007f9d6129..63d34e1167de4 100644 --- a/onnxruntime/python/tools/quantization/shape_inference.py +++ b/onnxruntime/python/tools/quantization/shape_inference.py @@ -9,7 +9,6 @@ import tempfile import traceback from pathlib import Path -from typing import Optional, Union import onnx @@ -23,8 +22,8 @@ def quant_pre_process( - input_model: Optional[Union[str, Path, onnx.ModelProto]] = None, - output_model_path: Optional[Union[str, Path]] = None, + input_model: str | Path | onnx.ModelProto | None = None, + output_model_path: str | Path | None = None, skip_optimization: bool = False, skip_onnx_shape: bool = False, skip_symbolic_shape: bool = False, @@ -34,7 +33,7 @@ def quant_pre_process( verbose: int = 0, save_as_external_data: bool = False, all_tensors_to_one_file: bool = False, - external_data_location: Optional[str] = None, + external_data_location: str | None = None, external_data_size_threshold: int = 1024, **deprecated_kwargs, ) -> None: diff --git a/onnxruntime/python/tools/symbolic_shape_infer.py b/onnxruntime/python/tools/symbolic_shape_infer.py index b9ff2159028d0..b1bf9c9d537e6 100755 --- a/onnxruntime/python/tools/symbolic_shape_infer.py +++ b/onnxruntime/python/tools/symbolic_shape_infer.py @@ -651,7 +651,7 @@ def _compute_on_sympy_data(self, node, op_func): is_list = [isinstance(v, list) for v in values] as_list = any(is_list) if as_list: - self.sympy_data_[node.output[0]] = [op_func(vs) for vs in zip(*values)] + self.sympy_data_[node.output[0]] = [op_func(vs) for vs in zip(*values, strict=False)] else: self.sympy_data_[node.output[0]] = op_func(values) @@ -722,21 +722,21 @@ def _compute_conv_pool_shape(self, node, channels_last=False): dilations = get_attribute(node, "dilations", [1] * rank) strides = get_attribute(node, "strides", [1] * rank) - effective_kernel_shape = [(k - 1) * d + 1 for k, d in zip(kernel_shape, dilations)] + effective_kernel_shape = [(k - 1) * d + 1 for k, d in zip(kernel_shape, dilations, strict=False)] pads = get_attribute(node, "pads") if pads is None: pads = [0] * (2 * rank) auto_pad = get_attribute(node, "auto_pad", b"NOTSET").decode("utf-8") if auto_pad != "VALID" and auto_pad != "NOTSET": try: - residual = [sympy.Mod(d, s) for d, s in zip(sympy_shape[-rank:], strides)] + residual = [sympy.Mod(d, s) for d, s in zip(sympy_shape[-rank:], strides, strict=False)] total_pads = [ max(0, (k - s) if r == 0 else (k - r)) - for k, s, r in zip(effective_kernel_shape, strides, residual) + for k, s, r in zip(effective_kernel_shape, strides, residual, strict=False) ] except TypeError: # sympy may throw TypeError: cannot determine truth value of Relational total_pads = [ - max(0, (k - s)) for k, s in zip(effective_kernel_shape, strides) + max(0, (k - s)) for k, s in zip(effective_kernel_shape, strides, strict=False) ] # assuming no residual if sympy throws error elif auto_pad == "VALID": total_pads = [] @@ -744,7 +744,7 @@ def _compute_conv_pool_shape(self, node, channels_last=False): total_pads = [0] * rank else: assert len(pads) == 2 * rank - total_pads = [p1 + p2 for p1, p2 in zip(pads[:rank], pads[rank:])] + total_pads = [p1 + p2 for p1, p2 in zip(pads[:rank], pads[rank:], strict=False)] ceil_mode = get_attribute(node, "ceil_mode", 0) for i in range(rank): @@ -815,7 +815,7 @@ def _fuse_tensor_type(self, node, out_idx, dst_type, src_type): f"{onnx.onnx_pb.TensorProto.DataType.Name(src_tensor_type.elem_type)}" ) if dst_tensor_type.HasField("shape"): - for di, ds in enumerate(zip(dst_tensor_type.shape.dim, src_tensor_type.shape.dim)): + for di, ds in enumerate(zip(dst_tensor_type.shape.dim, src_tensor_type.shape.dim, strict=False)): if ds[0] != ds[1]: # create a new symbolic dimension for node/out_idx/mismatch dim id in dst_tensor_type for tensor_type # for sequence_type, clear the dimension @@ -1222,7 +1222,7 @@ def _infer_Loop(self, node): # noqa: N802 else: si = subgraph.input[i_out + 1] si_shape = get_shape_from_value_info(si) - for di, dims in enumerate(zip(si_shape, so_shape)): + for di, dims in enumerate(zip(si_shape, so_shape, strict=False)): if dims[0] != dims[1]: new_dim = onnx.TensorShapeProto.Dimension() new_dim.dim_param = str(self._new_symbolic_dim_from_output(node, i_out, di)) @@ -1319,7 +1319,8 @@ def _infer_Pad(self, node): # noqa: N802 if pads is not None: assert len(pads) == 2 * rank new_sympy_shape = [ - d + pad_up + pad_down for d, pad_up, pad_down in zip(sympy_shape, pads[:rank], pads[rank:]) + d + pad_up + pad_down + for d, pad_up, pad_down in zip(sympy_shape, pads[:rank], pads[rank:], strict=False) ] self._update_computed_dims(new_sympy_shape) else: @@ -1679,7 +1680,9 @@ def _infer_Resize(self, node): # noqa: N802 if get_opset(self.out_mp_) <= 10: scales = self._try_get_value(node, 1) if scales is not None: - new_sympy_shape = [sympy.simplify(sympy.floor(d * s)) for d, s in zip(input_sympy_shape, scales)] + new_sympy_shape = [ + sympy.simplify(sympy.floor(d * s)) for d, s in zip(input_sympy_shape, scales, strict=False) + ] self._update_computed_dims(new_sympy_shape) vi.CopyFrom( helper.make_tensor_value_info( @@ -1707,7 +1710,7 @@ def _infer_Resize(self, node): # noqa: N802 scales = list(scales) new_sympy_shape = [ sympy.simplify(sympy.floor(d * (end - start) * scale)) - for d, start, end, scale in zip(input_sympy_shape, roi_start, roi_end, scales) + for d, start, end, scale in zip(input_sympy_shape, roi_start, roi_end, scales, strict=False) ] self._update_computed_dims(new_sympy_shape) else: @@ -1893,7 +1896,7 @@ def handle_negative_index(index, bound): for i in axes: new_sympy_shape[i] = self._new_symbolic_dim_from_output(node, 0, i) else: - for i, s, e, t in zip(axes, starts, ends, steps): + for i, s, e, t in zip(axes, starts, ends, steps, strict=False): e = handle_negative_index(e, new_sympy_shape[i]) # noqa: PLW2901 if is_literal(e): if e >= self.int_max_: @@ -2841,7 +2844,7 @@ def get_prereq(node): self._add_suggested_merge( [ s[i] if is_literal(s[i]) else str(s[i]) - for s, i in zip(shapes, dim_idx) + for s, i in zip(shapes, dim_idx, strict=False) if i >= 0 ] ) diff --git a/onnxruntime/python/tools/tensorrt/perf/benchmark.py b/onnxruntime/python/tools/tensorrt/perf/benchmark.py index 4fa5d0c0ea034..2152a66d1f2e7 100644 --- a/onnxruntime/python/tools/tensorrt/perf/benchmark.py +++ b/onnxruntime/python/tools/tensorrt/perf/benchmark.py @@ -607,7 +607,7 @@ def validate(all_ref_outputs, all_outputs, rtol, atol, percent_mismatch): output = outputs[j] # Compare the results with reference outputs - for ref_o, o in zip(ref_output, output): + for ref_o, o in zip(ref_output, output, strict=False): # abs(desired-actual) < rtol * abs(desired) + atol try: np.testing.assert_allclose(ref_o, o, rtol, atol) diff --git a/onnxruntime/python/tools/tensorrt/perf/build/build_image.py b/onnxruntime/python/tools/tensorrt/perf/build/build_image.py index 541dc4978dad1..0384300b99445 100644 --- a/onnxruntime/python/tools/tensorrt/perf/build/build_image.py +++ b/onnxruntime/python/tools/tensorrt/perf/build/build_image.py @@ -12,7 +12,6 @@ import shlex import subprocess import sys -from typing import List, Optional TRT_DOCKER_FILES = { "8.6_cuda11.8_cudnn8": "tools/ci_build/github/linux/docker/Dockerfile.ubuntu_cuda11_8_tensorrt8_6", @@ -23,7 +22,7 @@ } -def run_cmd(cmd: List[str]) -> Optional[int]: +def run_cmd(cmd: list[str]) -> int | None: """ Runs a shell command and returns the process's return code. @@ -38,7 +37,7 @@ def run_cmd(cmd: List[str]) -> Optional[int]: return pty.spawn(cmd) -def get_common_docker_build_args(args: argparse.Namespace) -> List[str]: +def get_common_docker_build_args(args: argparse.Namespace) -> list[str]: """ Returns a list of common 'docker build' command-line arguments/options. diff --git a/onnxruntime/python/tools/tensorrt/perf/parse_mem_concurrency_test.py b/onnxruntime/python/tools/tensorrt/perf/parse_mem_concurrency_test.py index 492de13fb42b5..b308066edacad 100644 --- a/onnxruntime/python/tools/tensorrt/perf/parse_mem_concurrency_test.py +++ b/onnxruntime/python/tools/tensorrt/perf/parse_mem_concurrency_test.py @@ -103,7 +103,7 @@ def parse_concurrency_test_log(input_path, output_path): # Parse mem_test log logs = ["valgrind.log", "concurrency_test.log"] csv_paths = ["mem_test.csv", "concurrency_test.csv"] - for log, csv_path in zip(logs, csv_paths): + for log, csv_path in zip(logs, csv_paths, strict=False): if os.path.exists(log): print(f"{identifier}: Parsing {log}") if log == logs[0]: @@ -112,7 +112,9 @@ def parse_concurrency_test_log(input_path, output_path): parse_concurrency_test_log(log, csv_path) # Upload to db - for csv_path, db_table_name in zip(csv_paths, ["ep_valgrind_record", "ep_concurrencytest_record"]): + for csv_path, db_table_name in zip( + csv_paths, ["ep_valgrind_record", "ep_concurrencytest_record"], strict=False + ): if os.path.exists(csv_path): table = pd.read_csv(csv_path) write_table( diff --git a/onnxruntime/python/tools/transformers/benchmark_helper.py b/onnxruntime/python/tools/transformers/benchmark_helper.py index d88e689521593..2a210729112d7 100644 --- a/onnxruntime/python/tools/transformers/benchmark_helper.py +++ b/onnxruntime/python/tools/transformers/benchmark_helper.py @@ -16,7 +16,7 @@ from datetime import datetime from enum import Enum from time import sleep -from typing import Any, Dict, List, Optional +from typing import Any import coloredlogs import numpy @@ -405,7 +405,7 @@ def set_random_seed(seed=123): # torch.backends.cudnn.deterministic = True -def get_gpu_info() -> Optional[List[Dict[str, Any]]]: +def get_gpu_info() -> list[dict[str, Any]] | None: from py3nvml.py3nvml import ( NVMLError, nvmlDeviceGetCount, @@ -459,7 +459,7 @@ def measure_cpu_usage(self): return max_usage @abstractmethod - def measure_gpu_usage(self) -> Optional[List[Dict[str, Any]]]: + def measure_gpu_usage(self) -> list[dict[str, Any]] | None: raise NotImplementedError() @@ -467,7 +467,7 @@ class CudaMemoryMonitor(MemoryMonitor): def __init__(self, keep_measuring=True): super().__init__(keep_measuring) - def measure_gpu_usage(self) -> Optional[List[Dict[str, Any]]]: + def measure_gpu_usage(self) -> list[dict[str, Any]] | None: from py3nvml.py3nvml import ( NVMLError, nvmlDeviceGetCount, diff --git a/onnxruntime/python/tools/transformers/bert_perf_test.py b/onnxruntime/python/tools/transformers/bert_perf_test.py index 17c5d3602bb3b..c506bf4539173 100644 --- a/onnxruntime/python/tools/transformers/bert_perf_test.py +++ b/onnxruntime/python/tools/transformers/bert_perf_test.py @@ -23,7 +23,6 @@ from dataclasses import dataclass from datetime import datetime from pathlib import Path -from typing import Optional import numpy as np import psutil @@ -55,8 +54,8 @@ class ModelSetting: segment_ids_name: str input_mask_name: str opt_level: int - input_tuning_results: Optional[str] - output_tuning_results: Optional[str] + input_tuning_results: str | None + output_tuning_results: str | None mask_type: int diff --git a/onnxruntime/python/tools/transformers/bert_test_data.py b/onnxruntime/python/tools/transformers/bert_test_data.py index ccf2497d61342..55a4e4e5824ed 100644 --- a/onnxruntime/python/tools/transformers/bert_test_data.py +++ b/onnxruntime/python/tools/transformers/bert_test_data.py @@ -10,7 +10,6 @@ import os import random from pathlib import Path -from typing import Dict, Optional, Tuple import numpy as np from onnx import ModelProto, TensorProto, numpy_helper @@ -157,7 +156,7 @@ def fake_input_mask_data( return data -def output_test_data(directory: str, inputs: Dict[str, np.ndarray]): +def output_test_data(directory: str, inputs: dict[str, np.ndarray]): """Output input tensors of test data to a directory Args: @@ -305,10 +304,10 @@ def get_graph_input_from_embed_node(onnx_model, embed_node, input_index): def find_bert_inputs( onnx_model: OnnxModel, - input_ids_name: Optional[str] = None, - segment_ids_name: Optional[str] = None, - input_mask_name: Optional[str] = None, -) -> Tuple[Optional[np.ndarray], Optional[np.ndarray], Optional[np.ndarray]]: + input_ids_name: str | None = None, + segment_ids_name: str | None = None, + input_mask_name: str | None = None, +) -> tuple[np.ndarray | None, np.ndarray | None, np.ndarray | None]: """Find graph inputs for BERT model. First, we will deduce inputs from EmbedLayerNormalization node. If not found, we will guess the meaning of graph inputs based on naming. @@ -397,10 +396,10 @@ def find_bert_inputs( def get_bert_inputs( onnx_file: str, - input_ids_name: Optional[str] = None, - segment_ids_name: Optional[str] = None, - input_mask_name: Optional[str] = None, -) -> Tuple[Optional[np.ndarray], Optional[np.ndarray], Optional[np.ndarray]]: + input_ids_name: str | None = None, + segment_ids_name: str | None = None, + input_mask_name: str | None = None, +) -> tuple[np.ndarray | None, np.ndarray | None, np.ndarray | None]: """Find graph inputs for BERT model. First, we will deduce inputs from EmbedLayerNormalization node. If not found, we will guess the meaning of graph inputs based on naming. @@ -531,9 +530,9 @@ def create_and_save_test_data( test_cases: int, seed: int, verbose: bool, - input_ids_name: Optional[str], - segment_ids_name: Optional[str], - input_mask_name: Optional[str], + input_ids_name: str | None, + segment_ids_name: str | None, + input_mask_name: str | None, only_input_tensors: bool, average_sequence_length: int, random_sequence_length: bool, diff --git a/onnxruntime/python/tools/transformers/convert_generation.py b/onnxruntime/python/tools/transformers/convert_generation.py index 5a26fedb5287d..68bf9e9e69059 100644 --- a/onnxruntime/python/tools/transformers/convert_generation.py +++ b/onnxruntime/python/tools/transformers/convert_generation.py @@ -48,7 +48,7 @@ import time from enum import Enum from pathlib import Path -from typing import Any, Dict, List, Optional, Union +from typing import Any import numpy as np import onnx @@ -86,7 +86,7 @@ def __str__(self): return self.value -def parse_arguments(argv: Optional[List[str]] = None) -> argparse.Namespace: +def parse_arguments(argv: list[str] | None = None) -> argparse.Namespace: """Parse arguments Args: @@ -883,8 +883,8 @@ def remove_shared_initializers( graph2: GraphProto, shared_prefix: str = "shared_", min_elements: int = 1024, - signature_cache1: Optional[dict] = None, - signature_cache2: Optional[dict] = None, + signature_cache1: dict | None = None, + signature_cache2: dict | None = None, ): """Remove initializers with same value from two graphs. @@ -1005,7 +1005,7 @@ def get_shared_initializers(encoder_model: ModelProto, decoder_model: ModelProto def move_initializers( graph: GraphProto, min_elements: int = 1024, -) -> List[TensorProto]: +) -> list[TensorProto]: """Remove initializers of a graph, when they have number of elements larger than a threshold. Args: @@ -2585,13 +2585,13 @@ def convert_generation_model(args: argparse.Namespace, generation_type: Generati def test_torch_performance( args: argparse.Namespace, - model: Union[GPT2LMHeadModel, T5ForConditionalGeneration], + model: GPT2LMHeadModel | T5ForConditionalGeneration, input_ids: torch.Tensor, attention_mask: torch.Tensor, eos_token_id: int, pad_token_id: int, - bad_words_ids: List[List[int]], -) -> Dict[str, Any]: + bad_words_ids: list[list[int]], +) -> dict[str, Any]: """Test PyTorch performance of text generation. Args: @@ -2661,7 +2661,7 @@ def create_attention_mask(input_ids, pad_token_id): return attention_mask -def test_gpt_model(args: argparse.Namespace, sentences: Optional[List[str]] = None, is_greedy: bool = False): +def test_gpt_model(args: argparse.Namespace, sentences: list[str] | None = None, is_greedy: bool = False): """Test GPT-2 model Args: @@ -2872,7 +2872,7 @@ def test_gpt_model(args: argparse.Namespace, sentences: Optional[List[str]] = No return output -def test_t5_model(args: argparse.Namespace, sentences: Optional[List[str]] = None): +def test_t5_model(args: argparse.Namespace, sentences: list[str] | None = None): """Test T5 or MT5 model Args: @@ -3061,7 +3061,7 @@ def test_t5_model(args: argparse.Namespace, sentences: Optional[List[str]] = Non return output -def main(argv: Optional[List[str]] = None, sentences: Optional[List[str]] = None): +def main(argv: list[str] | None = None, sentences: list[str] | None = None): """Main entry function Args: diff --git a/onnxruntime/python/tools/transformers/convert_to_packing_mode.py b/onnxruntime/python/tools/transformers/convert_to_packing_mode.py index e854312cae826..9a6388b3f350d 100644 --- a/onnxruntime/python/tools/transformers/convert_to_packing_mode.py +++ b/onnxruntime/python/tools/transformers/convert_to_packing_mode.py @@ -6,7 +6,6 @@ import argparse import logging import os -from typing import List, Union import coloredlogs from constants import ( @@ -26,15 +25,15 @@ class PackingAttentionBase: def __init__(self, model: OnnxModel, attention_op_type: str): self.model: OnnxModel = model - self.nodes_to_remove: List = [] - self.nodes_to_add: List = [] + self.nodes_to_remove: list = [] + self.nodes_to_add: list = [] self.prune_graph: bool = False self.node_name_to_graph_name: dict = {} self.this_graph_name: str = self.model.model.graph.name self.attention_op_type = attention_op_type self.attention_nodes = self.model.get_nodes_by_op_type(attention_op_type) - def _try_getting_attention_mask(self) -> Union[str, None]: + def _try_getting_attention_mask(self) -> str | None: mask_index = ( AttentionInputIDs.MASK_INDEX if self.attention_op_type == Operators.ATTENTION @@ -54,13 +53,13 @@ def _try_getting_attention_mask(self) -> Union[str, None]: return attention_mask - def _try_getting_first_attention(self) -> Union[NodeProto, None]: + def _try_getting_first_attention(self) -> NodeProto | None: if len(self.attention_nodes) <= 0: return None return self.attention_nodes[0] - def _try_getting_last_layernorm(self) -> Union[NodeProto, None]: + def _try_getting_last_layernorm(self) -> NodeProto | None: last_layernorm_node = None for node in self.model.nodes(): if node.op_type == Operators.LAYERNORM or node.op_type == Operators.SKIPLAYERNORM: @@ -70,7 +69,7 @@ def _try_getting_last_layernorm(self) -> Union[NodeProto, None]: def _are_attentions_supported(self) -> bool: raise NotImplementedError() - def _insert_removepadding_node(self, inputs: List[str], outputs: List[str]) -> None: + def _insert_removepadding_node(self, inputs: list[str], outputs: list[str]) -> None: new_node = helper.make_node( Operators.REMOVEPADDING, inputs=inputs, @@ -82,7 +81,7 @@ def _insert_removepadding_node(self, inputs: List[str], outputs: List[str]) -> N self.nodes_to_add.append(new_node) self.node_name_to_graph_name[new_node.name] = self.this_graph_name - def _insert_restorepadding_node(self, inputs: List[str], outputs: List[str]) -> None: + def _insert_restorepadding_node(self, inputs: list[str], outputs: list[str]) -> None: new_node = helper.make_node( Operators.RESTOREPADDING, inputs=inputs, @@ -97,7 +96,7 @@ def _insert_restorepadding_node(self, inputs: List[str], outputs: List[str]) -> def _replace_attention_with_packing_attention(self, token_offset: str, cumulative_sequence_length: str) -> None: raise NotImplementedError() - def _get_input_to_remove_padding(self, first_attention_node) -> Union[str, None]: + def _get_input_to_remove_padding(self, first_attention_node) -> str | None: if self.attention_op_type == Operators.ATTENTION: return first_attention_node.input[AttentionInputIDs.INPUT] return None @@ -306,7 +305,7 @@ def _replace_attention_with_packing_attention(self, token_offset: str, cumulativ logger.info("Converted %d MultiHeadAttention nodes to PackedMultiHeadAttention.", len(self.attention_nodes)) logger.info("Converted %d GatedRelativePositionBias nodes to packing mode.", gated_relative_pos_bias_count) - def _get_input_to_remove_padding(self, first_attention_node) -> Union[str, None]: + def _get_input_to_remove_padding(self, first_attention_node) -> str | None: # When there are query, key and value inputs, we need to find the first input of the parent MatMul node. matmul = self.model.get_parent(first_attention_node, 0) if matmul and matmul.op_type == "MatMul": diff --git a/onnxruntime/python/tools/transformers/float16.py b/onnxruntime/python/tools/transformers/float16.py index e9ac4a64f9fe5..349f5bb51fe47 100644 --- a/onnxruntime/python/tools/transformers/float16.py +++ b/onnxruntime/python/tools/transformers/float16.py @@ -16,7 +16,6 @@ import logging import os import tempfile -from typing import Dict import numpy as np import onnx @@ -304,7 +303,7 @@ def convert_float_to_float16( value_info_list.append(new_value_info) io_casts.add(node_name) - fp32_initializers: Dict[str, InitializerTracker] = {} + fp32_initializers: dict[str, InitializerTracker] = {} while queue: next_level = [] for q in queue: diff --git a/onnxruntime/python/tools/transformers/fusion_attention.py b/onnxruntime/python/tools/transformers/fusion_attention.py index 56b5ae93e7221..c02cf5cbb4e54 100644 --- a/onnxruntime/python/tools/transformers/fusion_attention.py +++ b/onnxruntime/python/tools/transformers/fusion_attention.py @@ -3,7 +3,6 @@ # Licensed under the MIT License. # -------------------------------------------------------------------------- from logging import getLogger -from typing import List, Optional, Tuple, Union import numpy as np from fusion_base import Fusion @@ -42,7 +41,7 @@ def get_first_mask(self): assert len(self.mask_indice) > 0 return next(iter(self.mask_indice)) - def process_mask(self, mask_2d: str) -> Optional[str]: + def process_mask(self, mask_2d: str) -> str | None: if self.mask_format == AttentionMaskFormat.NoMask: return None @@ -111,10 +110,10 @@ def __init__( model: OnnxModel, hidden_size: int, num_heads: int, - attention_mask: Optional[AttentionMask] = None, + attention_mask: AttentionMask | None = None, use_multi_head_attention: bool = False, disable_multi_head_attention_bias: bool = False, - search_op_types: List[str] = ["SkipLayerNormalization", "LayerNormalization"], # noqa: B006 + search_op_types: list[str] = ["SkipLayerNormalization", "LayerNormalization"], # noqa: B006 ): attention_op_name = "MultiHeadAttention" if use_multi_head_attention else "Attention" super().__init__(model, attention_op_name, search_op_types) @@ -132,7 +131,7 @@ def __init__( self.shape_infer = None self.shape_infer_done = True - def get_num_heads_and_hidden_size_from_concat(self, concat: NodeProto) -> Tuple[int, int]: + def get_num_heads_and_hidden_size_from_concat(self, concat: NodeProto) -> tuple[int, int]: """ Detect num_heads and hidden_size from Concat node in the following subgraph: @@ -163,7 +162,7 @@ def get_num_heads_and_hidden_size_from_concat(self, concat: NodeProto) -> Tuple[ return self.num_heads, self.hidden_size - def get_num_heads_and_hidden_size(self, reshape_q: NodeProto) -> Tuple[int, int]: + def get_num_heads_and_hidden_size(self, reshape_q: NodeProto) -> tuple[int, int]: """Detect num_heads and hidden_size from a reshape node. Args: @@ -358,10 +357,10 @@ def split_kv(self, present_k_name: str, present_v_name: str, kv_node: str): def create_combined_qkv_bias( self, q_add: NodeProto, - k_add: Union[NodeProto, None], - v_add: Union[NodeProto, None], + k_add: NodeProto | None, + v_add: NodeProto | None, name_prefix: str, - ) -> Union[NodeProto, None]: + ) -> NodeProto | None: q_bias = self.model.get_initializer(q_add.input[1]) or self.model.get_initializer(q_add.input[0]) qb = NumpyHelper.to_array(q_bias) kb = np.zeros_like(qb) @@ -391,9 +390,9 @@ def create_packed_qkv_matmul_node( k_matmul: NodeProto, v_matmul: NodeProto, q_add: NodeProto, - k_add: Union[NodeProto, None], - v_add: Union[NodeProto, None], - ) -> Tuple[NodeProto, NodeProto, NodeProto]: + k_add: NodeProto | None, + v_add: NodeProto | None, + ) -> tuple[NodeProto, NodeProto, NodeProto]: """Create packed QKV MatMul node before MultiHeadAttention node. This is for the scenario where an Attention node should be created but cannot be created because past_key and past_value are separate inputs and not one concatenated input. @@ -532,11 +531,11 @@ def create_packed_qkv_matmul_node( def create_multihead_attention_node( self, q_matmul: NodeProto, - k_matmul: Union[NodeProto, str, None], - v_matmul: Union[NodeProto, str, None], + k_matmul: NodeProto | str | None, + v_matmul: NodeProto | str | None, q_add: NodeProto, - k_add: Union[NodeProto, None], - v_add: Union[NodeProto, None], + k_add: NodeProto | None, + v_add: NodeProto | None, num_heads: int, hidden_size: int, output: str, @@ -547,7 +546,7 @@ def create_multihead_attention_node( present_k: str = "", present_v: str = "", packed_qkv: bool = False, - ) -> Union[NodeProto, None]: + ) -> NodeProto | None: """Create a MultiHeadAttention node. Args: @@ -647,7 +646,7 @@ def create_multihead_attention_node( def create_attention_node( self, - mask_index: Optional[str], + mask_index: str | None, q_matmul: NodeProto, k_matmul: NodeProto, v_matmul: NodeProto, @@ -663,9 +662,9 @@ def create_attention_node( past_v: str = "", present_k: str = "", present_v: str = "", - scale: Optional[float] = None, + scale: float | None = None, causal: bool = False, - ) -> Union[NodeProto, None]: + ) -> NodeProto | None: """Create an Attention node. Args: @@ -762,7 +761,7 @@ def create_attention_node( qkv_weight_dim = 3 * qw_out_size qkv_bias_dim = 0 - qkv_bias: Optional[np.ndarray] = None + qkv_bias: np.ndarray | None = None if has_bias: qb = NumpyHelper.to_array(q_bias) kb = NumpyHelper.to_array(k_bias) diff --git a/onnxruntime/python/tools/transformers/fusion_attention_clip.py b/onnxruntime/python/tools/transformers/fusion_attention_clip.py index 16e2c36bfd092..a4a7a5c8c1890 100644 --- a/onnxruntime/python/tools/transformers/fusion_attention_clip.py +++ b/onnxruntime/python/tools/transformers/fusion_attention_clip.py @@ -3,7 +3,6 @@ # Licensed under the MIT License. # -------------------------------------------------------------------------- from logging import getLogger -from typing import Tuple from fusion_attention import AttentionMask, FusionAttention from fusion_options import AttentionMaskFormat @@ -36,7 +35,7 @@ def __init__( search_op_types=["SkipLayerNormalization"], ) - def get_num_heads_and_hidden_size(self, reshape_q: NodeProto) -> Tuple[int, int]: + def get_num_heads_and_hidden_size(self, reshape_q: NodeProto) -> tuple[int, int]: """Detect num_heads and hidden_size for ONNX model from MiDaS Args: reshape_q (NodeProto): reshape node for q diff --git a/onnxruntime/python/tools/transformers/fusion_attention_sam2.py b/onnxruntime/python/tools/transformers/fusion_attention_sam2.py index ce7ddd3c1050e..f66d7d12d1e5f 100644 --- a/onnxruntime/python/tools/transformers/fusion_attention_sam2.py +++ b/onnxruntime/python/tools/transformers/fusion_attention_sam2.py @@ -3,7 +3,6 @@ # Licensed under the MIT License. # -------------------------------------------------------------------------- from logging import getLogger -from typing import Tuple, Union import numpy as np from fusion_base import Fusion @@ -97,7 +96,7 @@ def get_hidden_size(self, layernorm_node): def get_num_heads_and_hidden_size( self, reshape_q: NodeProto, layernorm_node: NodeProto, is_encoder: bool = False - ) -> Tuple[int, int]: + ) -> tuple[int, int]: """Detect num_heads and hidden_size. Args: @@ -142,7 +141,7 @@ def create_attention_node( num_heads: int, hidden_size: int, output: str, - ) -> Union[NodeProto, None]: + ) -> NodeProto | None: """Create an Attention node. Args: diff --git a/onnxruntime/python/tools/transformers/fusion_attention_unet.py b/onnxruntime/python/tools/transformers/fusion_attention_unet.py index 9a353e7e2d675..1bdf4f24f3621 100644 --- a/onnxruntime/python/tools/transformers/fusion_attention_unet.py +++ b/onnxruntime/python/tools/transformers/fusion_attention_unet.py @@ -3,7 +3,6 @@ # Licensed under the MIT License. # -------------------------------------------------------------------------- from logging import getLogger -from typing import Tuple, Union import numpy as np from fusion_base import Fusion @@ -91,7 +90,7 @@ def get_hidden_size(self, layernorm_node): def get_num_heads_and_hidden_size( self, reshape_q: NodeProto, layernorm_node: NodeProto, is_torch2: bool = False - ) -> Tuple[int, int]: + ) -> tuple[int, int]: """Detect num_heads and hidden_size. Args: @@ -132,7 +131,7 @@ def create_attention_node( hidden_size: int, input: str, output: str, - ) -> Union[NodeProto, None]: + ) -> NodeProto | None: """Create an Attention node. Args: @@ -390,7 +389,7 @@ def create_attention_node_lora( hidden_size: int, input: str, output: str, - ) -> Union[NodeProto, None]: + ) -> NodeProto | None: """Create an Attention node. Args: diff --git a/onnxruntime/python/tools/transformers/fusion_attention_vae.py b/onnxruntime/python/tools/transformers/fusion_attention_vae.py index 151c04f9334fe..2b57fa2c418cf 100644 --- a/onnxruntime/python/tools/transformers/fusion_attention_vae.py +++ b/onnxruntime/python/tools/transformers/fusion_attention_vae.py @@ -3,7 +3,6 @@ # Licensed under the MIT License. # -------------------------------------------------------------------------- from logging import getLogger -from typing import Tuple, Union import numpy as np from fusion_base import Fusion @@ -27,7 +26,7 @@ def __init__(self, model: OnnxModel, hidden_size: int, num_heads: int): self.num_heads_warning = True self.hidden_size_warning = True - def get_num_heads_and_hidden_size(self, reshape_q: NodeProto, add_q: NodeProto) -> Tuple[int, int]: + def get_num_heads_and_hidden_size(self, reshape_q: NodeProto, add_q: NodeProto) -> tuple[int, int]: """Detect num_heads and hidden_size from a reshape node. Args: @@ -80,7 +79,7 @@ def create_attention_node( hidden_size: int, input_name: str, output_name: str, - ) -> Union[NodeProto, None]: + ) -> NodeProto | None: """Create an Attention node. Args: diff --git a/onnxruntime/python/tools/transformers/fusion_base.py b/onnxruntime/python/tools/transformers/fusion_base.py index 67f4f0b55cff8..a923e14c493f4 100644 --- a/onnxruntime/python/tools/transformers/fusion_base.py +++ b/onnxruntime/python/tools/transformers/fusion_base.py @@ -3,8 +3,9 @@ # Licensed under the MIT License. # -------------------------------------------------------------------------- from collections import defaultdict +from collections.abc import Sequence from logging import getLogger -from typing import Any, Dict, List, Optional, Sequence, Union +from typing import Any import numpy as np from onnx import NodeProto, helper @@ -22,18 +23,18 @@ def __init__( self, model: OnnxModel, fused_op_type: str, - search_op_types: Union[str, List[str]], + search_op_types: str | list[str], description: str = "", ): - self.search_op_types: List[str] = [search_op_types] if isinstance(search_op_types, str) else search_op_types + self.search_op_types: list[str] = [search_op_types] if isinstance(search_op_types, str) else search_op_types self.fused_op_type: str = fused_op_type self.description: str = f"{fused_op_type}({description})" if description else fused_op_type self.model: OnnxModel = model - self.nodes_to_remove: List = [] - self.nodes_to_add: List = [] + self.nodes_to_remove: list = [] + self.nodes_to_add: list = [] self.prune_graph: bool = False self.node_name_to_graph_name: dict = {} - self.this_graph_name: Optional[str] = None + self.this_graph_name: str | None = None # It is optional that subclass updates fused_count since we will also check nodes_to_add to get counter. self.fused_count: defaultdict = defaultdict(int) @@ -46,8 +47,8 @@ def increase_counter(self, fused_op_name: str): def fuse( self, node: NodeProto, - input_name_to_nodes: Dict[str, List[NodeProto]], - output_name_to_node: Dict[str, NodeProto], + input_name_to_nodes: dict[str, list[NodeProto]], + output_name_to_node: dict[str, NodeProto], ): """Interface for fusion that starts from a node""" raise NotImplementedError @@ -114,7 +115,7 @@ def add_initializer(self, name: str, data_type: int, dims: Sequence[int], vals: self.model.add_initializer(tensor, self.this_graph_name) return tensor - def add_nodes_to_remove(self, nodes: List[NodeProto]): + def add_nodes_to_remove(self, nodes: list[NodeProto]): # Some nodes are shared between paths (e.g. rotary embedding nodes in the Q and K paths). # When path A is fused, its shared nodes are added to `self.nodes_to_remove`. But when path B # is fused, its shared nodes are also added to `self.nodes_to_remove`. When the nodes are @@ -131,7 +132,7 @@ def add_nodes_to_remove(self, nodes: List[NodeProto]): if node not in self.nodes_to_remove: self.nodes_to_remove.append(node) - def add_nodes_to_remove_with_nodes_to_keep(self, nodes: List[NodeProto], nodes_to_keep: List[NodeProto]): + def add_nodes_to_remove_with_nodes_to_keep(self, nodes: list[NodeProto], nodes_to_keep: list[NodeProto]): for node in nodes: if node not in self.nodes_to_remove and node not in nodes_to_keep: self.nodes_to_remove.append(node) diff --git a/onnxruntime/python/tools/transformers/fusion_bias_add.py b/onnxruntime/python/tools/transformers/fusion_bias_add.py index 8489af0940983..1cb4edad04ffe 100644 --- a/onnxruntime/python/tools/transformers/fusion_bias_add.py +++ b/onnxruntime/python/tools/transformers/fusion_bias_add.py @@ -3,7 +3,6 @@ # Licensed under the MIT License. # -------------------------------------------------------------------------- from logging import getLogger -from typing import Dict from fusion_base import Fusion from numpy import ndarray @@ -17,7 +16,7 @@ class FusionBiasAdd(Fusion): def __init__(self, model: OnnxModel): super().__init__(model, "BiasAdd", "Add") - def fuse(self, add_node, input_name_to_nodes: Dict, output_name_to_node: Dict): + def fuse(self, add_node, input_name_to_nodes: dict, output_name_to_node: dict): """ Fuse Add bias and Add skip connection into BiasAdd """ diff --git a/onnxruntime/python/tools/transformers/fusion_biassplitgelu.py b/onnxruntime/python/tools/transformers/fusion_biassplitgelu.py index 67a7c0fb9ceb3..1118809fdf6d3 100644 --- a/onnxruntime/python/tools/transformers/fusion_biassplitgelu.py +++ b/onnxruntime/python/tools/transformers/fusion_biassplitgelu.py @@ -3,7 +3,6 @@ # Licensed under the MIT License. # -------------------------------------------------------------------------- from logging import getLogger -from typing import Dict from fusion_base import Fusion from onnx import helper @@ -16,7 +15,7 @@ class FusionBiasSplitGelu(Fusion): def __init__(self, model: OnnxModel): super().__init__(model, "BiasSplitGelu", "Gelu") - def fuse(self, gelu_node, input_name_to_nodes: Dict, output_name_to_node: Dict): + def fuse(self, gelu_node, input_name_to_nodes: dict, output_name_to_node: dict): """ [root] --->Add --------------------> Slice ---------------> Mul --> | ^ ^ diff --git a/onnxruntime/python/tools/transformers/fusion_embedlayer.py b/onnxruntime/python/tools/transformers/fusion_embedlayer.py index 70ff57f0626e1..66ef06097aa58 100644 --- a/onnxruntime/python/tools/transformers/fusion_embedlayer.py +++ b/onnxruntime/python/tools/transformers/fusion_embedlayer.py @@ -4,7 +4,6 @@ # -------------------------------------------------------------------------- from logging import getLogger -from typing import Dict, List, Optional, Tuple, Union from fusion_base import Fusion from fusion_utils import FusionUtils @@ -35,7 +34,7 @@ def __init__(self, model: OnnxModel, description: str = "no mask"): self.attention = None self.embed_node = None - def match_two_gather(self, add: NodeProto) -> Union[None, Tuple[NodeProto, NodeProto]]: + def match_two_gather(self, add: NodeProto) -> None | tuple[NodeProto, NodeProto]: gather_0_path = self.model.match_parent_path(add, ["Gather"], [0]) if gather_0_path is None: return None @@ -49,7 +48,7 @@ def match_two_gather(self, add: NodeProto) -> Union[None, Tuple[NodeProto, NodeP def check_attention_subgraph( self, layernorm: NodeProto, - input_name_to_nodes: Dict[str, List[NodeProto]], + input_name_to_nodes: dict[str, list[NodeProto]], is_distil_bert: bool, ) -> bool: """Check that LayerNormalization has a child of Attention node or subgraph like Attention. @@ -399,7 +398,7 @@ def check_embedding(self, word_embedding_gather, segment_embedding_gather, posit return True - def cast_to_int32(self, input_name: str) -> Tuple[str, Union[None, NodeProto]]: + def cast_to_int32(self, input_name: str) -> tuple[str, None | NodeProto]: """Cast a graph input or node input to int32. Args: @@ -428,8 +427,8 @@ def create_fused_node( layernorm: NodeProto, word_embedding_gather: NodeProto, position_embedding_gather: NodeProto, - segment_embedding_gather: Union[None, NodeProto], - position_ids: Optional[str] = None, + segment_embedding_gather: None | NodeProto, + position_ids: str | None = None, embedding_sum_output=False, embedding_sum_name=None, ): diff --git a/onnxruntime/python/tools/transformers/fusion_fastgelu.py b/onnxruntime/python/tools/transformers/fusion_fastgelu.py index e2bb8027c8608..99f716193adb6 100644 --- a/onnxruntime/python/tools/transformers/fusion_fastgelu.py +++ b/onnxruntime/python/tools/transformers/fusion_fastgelu.py @@ -3,7 +3,6 @@ # Licensed under the MIT License. # -------------------------------------------------------------------------- from logging import getLogger -from typing import Dict, Optional from fusion_base import Fusion from onnx import helper @@ -16,7 +15,7 @@ class FusionFastGelu(Fusion): def __init__(self, model: OnnxModel): super().__init__(model, "FastGelu", "Tanh") - def fuse(self, tanh_node, input_name_to_nodes: Dict, output_name_to_node: Dict): + def fuse(self, tanh_node, input_name_to_nodes: dict, output_name_to_node: dict): if self.fuse_1(tanh_node, input_name_to_nodes, output_name_to_node): return @@ -29,7 +28,7 @@ def fuse(self, tanh_node, input_name_to_nodes: Dict, output_name_to_node: Dict): if self.fuse_4(tanh_node, input_name_to_nodes, output_name_to_node): return - def fuse_1(self, tanh_node, input_name_to_nodes, output_name_to_node) -> Optional[bool]: + def fuse_1(self, tanh_node, input_name_to_nodes, output_name_to_node) -> bool | None: """ Fuse Gelu with tanh into one node: +---------------------------+ @@ -137,7 +136,7 @@ def fuse_1(self, tanh_node, input_name_to_nodes, output_name_to_node) -> Optiona self.node_name_to_graph_name[fused_node.name] = self.this_graph_name return True - def fuse_2(self, tanh_node, input_name_to_nodes: Dict, output_name_to_node: Dict) -> Optional[bool]: + def fuse_2(self, tanh_node, input_name_to_nodes: dict, output_name_to_node: dict) -> bool | None: """ This pattern is from Tensorflow model. Fuse Gelu with tanh into one node: @@ -246,7 +245,7 @@ def fuse_2(self, tanh_node, input_name_to_nodes: Dict, output_name_to_node: Dict self.node_name_to_graph_name[fused_node.name] = self.this_graph_name return True - def fuse_3(self, tanh_node, input_name_to_nodes: Dict, output_name_to_node: Dict) -> Optional[bool]: + def fuse_3(self, tanh_node, input_name_to_nodes: dict, output_name_to_node: dict) -> bool | None: """ OpenAI's gelu implementation, also used in Megatron: Gelu(x) = x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1.0 + 0.044715 * x * x))) @@ -362,7 +361,7 @@ def fuse_3(self, tanh_node, input_name_to_nodes: Dict, output_name_to_node: Dict self.node_name_to_graph_name[fused_node.name] = self.this_graph_name return True - def fuse_4(self, tanh_node, input_name_to_nodes: Dict, output_name_to_node: Dict) -> Optional[bool]: + def fuse_4(self, tanh_node, input_name_to_nodes: dict, output_name_to_node: dict) -> bool | None: """ This pattern is from stable diffusion 3.5 model. Fuse Gelu with tanh into one node: diff --git a/onnxruntime/python/tools/transformers/fusion_gelu.py b/onnxruntime/python/tools/transformers/fusion_gelu.py index 6be5140c070d0..12f7d82a9c0af 100644 --- a/onnxruntime/python/tools/transformers/fusion_gelu.py +++ b/onnxruntime/python/tools/transformers/fusion_gelu.py @@ -3,7 +3,6 @@ # Licensed under the MIT License. # -------------------------------------------------------------------------- from logging import getLogger -from typing import Dict, Optional from fusion_base import Fusion from onnx import helper @@ -16,14 +15,14 @@ class FusionGelu(Fusion): def __init__(self, model: OnnxModel): super().__init__(model, "Gelu", "Erf") - def fuse(self, erf_node, input_name_to_nodes: Dict, output_name_to_node: Dict): + def fuse(self, erf_node, input_name_to_nodes: dict, output_name_to_node: dict): if self.fuse_1(erf_node, input_name_to_nodes, output_name_to_node): return if self.fuse_2(erf_node, input_name_to_nodes, output_name_to_node): return self.fuse_3(erf_node, input_name_to_nodes, output_name_to_node) - def fuse_1(self, erf_node, input_name_to_nodes: Dict, output_name_to_node: Dict) -> Optional[bool]: + def fuse_1(self, erf_node, input_name_to_nodes: dict, output_name_to_node: dict) -> bool | None: """ This pattern is from PyTorch model Fuse Gelu with Erf into one node: @@ -107,7 +106,7 @@ def fuse_1(self, erf_node, input_name_to_nodes: Dict, output_name_to_node: Dict) self.increase_counter("Gelu") return True - def fuse_2(self, erf_node, input_name_to_nodes: Dict, output_name_to_node: Dict) -> Optional[bool]: + def fuse_2(self, erf_node, input_name_to_nodes: dict, output_name_to_node: dict) -> bool | None: """ This pattern is from Keras model Fuse Gelu with Erf into one node: @@ -184,7 +183,7 @@ def fuse_2(self, erf_node, input_name_to_nodes: Dict, output_name_to_node: Dict) self.increase_counter("Gelu") return True - def fuse_3(self, erf_node, input_name_to_nodes: Dict, output_name_to_node: Dict) -> Optional[bool]: + def fuse_3(self, erf_node, input_name_to_nodes: dict, output_name_to_node: dict) -> bool | None: """ This pattern is from TensorFlow model Fuse Gelu with Erf into one node: diff --git a/onnxruntime/python/tools/transformers/fusion_gemmfastgelu.py b/onnxruntime/python/tools/transformers/fusion_gemmfastgelu.py index 4d9913f427b37..23eee1413ff9f 100644 --- a/onnxruntime/python/tools/transformers/fusion_gemmfastgelu.py +++ b/onnxruntime/python/tools/transformers/fusion_gemmfastgelu.py @@ -4,7 +4,6 @@ # -------------------------------------------------------------------------- from logging import getLogger -from typing import Dict, List, Union from fusion_base import Fusion from fusion_utils import NumpyHelper @@ -20,13 +19,13 @@ def __init__(self, model: OnnxModel): self.shape_infer = None self.shape_infer_done = False - def get_dimensions_from_tensor_proto(self, tensor_proto: TensorProto) -> Union[int, None]: + def get_dimensions_from_tensor_proto(self, tensor_proto: TensorProto) -> int | None: if tensor_proto.type.tensor_type.HasField("shape"): return len(tensor_proto.type.tensor_type.shape.dim) else: return None - def get_dimensions(self, input_name: str) -> Union[int, None]: + def get_dimensions(self, input_name: str) -> int | None: graph_input = self.model.find_graph_input(input_name) if graph_input: return self.get_dimensions_from_tensor_proto(graph_input) @@ -43,8 +42,8 @@ def get_dimensions(self, input_name: str) -> Union[int, None]: def fuse( self, node: NodeProto, - input_name_to_nodes: Dict[str, List[NodeProto]], - output_name_to_node: Dict[str, NodeProto], + input_name_to_nodes: dict[str, list[NodeProto]], + output_name_to_node: dict[str, NodeProto], ): """ This pattern is from PyTorch bert model diff --git a/onnxruntime/python/tools/transformers/fusion_group_norm.py b/onnxruntime/python/tools/transformers/fusion_group_norm.py index c9bf52234d696..2efec3e6ac6e8 100644 --- a/onnxruntime/python/tools/transformers/fusion_group_norm.py +++ b/onnxruntime/python/tools/transformers/fusion_group_norm.py @@ -3,7 +3,6 @@ # Licensed under the MIT License. # -------------------------------------------------------------------------- from logging import getLogger -from typing import Dict import numpy as np from fusion_base import Fusion @@ -18,7 +17,7 @@ def __init__(self, model: OnnxModel, channels_last=True): super().__init__(model, "GroupNorm", "Add") self.channels_last = channels_last - def fuse(self, add_node, input_name_to_nodes: Dict, output_name_to_node: Dict): + def fuse(self, add_node, input_name_to_nodes: dict, output_name_to_node: dict): """ Fuse Group Normalization subgraph into one node GroupNorm. The following is the pattern with swish activation: diff --git a/onnxruntime/python/tools/transformers/fusion_layernorm.py b/onnxruntime/python/tools/transformers/fusion_layernorm.py index 277bd0799cf16..1c96c54d9de35 100644 --- a/onnxruntime/python/tools/transformers/fusion_layernorm.py +++ b/onnxruntime/python/tools/transformers/fusion_layernorm.py @@ -3,7 +3,6 @@ # Licensed under the MIT License. # -------------------------------------------------------------------------- from logging import getLogger -from typing import Dict, List from fusion_base import Fusion from onnx import TensorProto, helper @@ -18,7 +17,7 @@ def __init__(self, model: OnnxModel, check_constant_and_dimension: bool = True, self.check_constant_and_dimension = check_constant_and_dimension self.force = force - def fuse(self, node, input_name_to_nodes: Dict, output_name_to_node: Dict): + def fuse(self, node, input_name_to_nodes: dict, output_name_to_node: dict): """ Fuse Layer Normalization subgraph into one node LayerNormalization: +----------------------+ @@ -184,7 +183,7 @@ def get_weight_or_bias(self, output_name, description): return value.reshape([value.shape[0]]) - def create_transpose_node(self, input_name: str, perm: List[int], output_name=None): + def create_transpose_node(self, input_name: str, perm: list[int], output_name=None): """Append a Transpose node after an input""" node_name = self.model.create_node_name("Transpose") @@ -196,7 +195,7 @@ def create_transpose_node(self, input_name: str, perm: List[int], output_name=No return transpose_node - def fuse(self, node, input_name_to_nodes: Dict, output_name_to_node: Dict): + def fuse(self, node, input_name_to_nodes: dict, output_name_to_node: dict): """ Fuse Layer Normalization subgraph into one node LayerNormalization: +----------------------+ @@ -328,7 +327,7 @@ class FusionLayerNormalizationTF(Fusion): def __init__(self, model: OnnxModel): super().__init__(model, "LayerNormalization", "Add", "TF") - def fuse(self, node, input_name_to_nodes: Dict, output_name_to_node: Dict): + def fuse(self, node, input_name_to_nodes: dict, output_name_to_node: dict): """ Layer Norm from Tensorflow model(using keras2onnx or tf2onnx): +------------------------------------+ diff --git a/onnxruntime/python/tools/transformers/fusion_mha_mmdit.py b/onnxruntime/python/tools/transformers/fusion_mha_mmdit.py index dcad55c13eb49..48f6f9a9686ee 100644 --- a/onnxruntime/python/tools/transformers/fusion_mha_mmdit.py +++ b/onnxruntime/python/tools/transformers/fusion_mha_mmdit.py @@ -3,7 +3,6 @@ # Licensed under the MIT License. # -------------------------------------------------------------------------- from logging import getLogger -from typing import Dict, Optional import numpy as np from fusion_base import Fusion @@ -131,7 +130,7 @@ def reshape_to_3d(self, input_name: str, output_name: str) -> str: self.node_name_to_graph_name[reshape_q.name] = self.this_graph_name return reshape_q.output[0] - def adjust_query_from_bnsh_to_bsd_no_concat(self, mul_q: NodeProto, output_name_to_node) -> Optional[str]: + def adjust_query_from_bnsh_to_bsd_no_concat(self, mul_q: NodeProto, output_name_to_node) -> str | None: """ MultiHeadAttenion requires query in BSD format. This function adjusts query from BNSH to BSD format. @@ -179,7 +178,7 @@ def adjust_query_from_bnsh_to_bsd_no_concat(self, mul_q: NodeProto, output_name_ return self.reshape_to_3d(sln_a.output[0], sln_output + "_BSD") - def adjust_query_from_bnsh_to_bsd(self, mul_q: NodeProto, output_name_to_node) -> Optional[str]: + def adjust_query_from_bnsh_to_bsd(self, mul_q: NodeProto, output_name_to_node) -> str | None: """ MultiHeadAttenion requires query in BSD format. This function adjusts query from BNSH to BSD format. @@ -294,7 +293,7 @@ def update_unsqueeze_axes_1_to_2(self, unsqueeze: NodeProto) -> str: return updated_unsqueeze_output - def update_unsqueeze_axes(self, add: NodeProto, output_name_to_node: Dict[str, NodeProto]) -> bool: + def update_unsqueeze_axes(self, add: NodeProto, output_name_to_node: dict[str, NodeProto]) -> bool: """ Update axes of Unsqueeze from [1] to [2] in the following pattern: Unsqueeze Unsqueeze @@ -347,7 +346,7 @@ def update_unsqueeze_axes(self, add: NodeProto, output_name_to_node: Dict[str, N nodes_b[0].input[1] = self.update_unsqueeze_axes_1_to_2(nodes_b[1]) return True - def adjust_flux_query_from_bnsh_to_bsd(self, mul_q: NodeProto, output_name_to_node) -> Optional[str]: + def adjust_flux_query_from_bnsh_to_bsd(self, mul_q: NodeProto, output_name_to_node) -> str | None: """ Adjust graph to change query format from BNSH to BSD for Flux model. Note that the graph pattern is complex, and we only do a shallow match here. @@ -431,7 +430,7 @@ def adjust_flux_query_from_bnsh_to_bsd(self, mul_q: NodeProto, output_name_to_no return self.reshape_to_3d(add.output[0], add.output[0] + "_BSD") - def adjust_flux_single_query_from_bnsh_to_bsd(self, mul_q: NodeProto, output_name_to_node) -> Optional[str]: + def adjust_flux_single_query_from_bnsh_to_bsd(self, mul_q: NodeProto, output_name_to_node) -> str | None: """ Adjust graph to change query format from BNSH to BSD for Flux model. Note that the graph pattern is complex, and we only do a shallow match here. @@ -482,7 +481,7 @@ def adjust_flux_single_query_from_bnsh_to_bsd(self, mul_q: NodeProto, output_nam return self.reshape_to_3d(add.output[0], add.output[0] + "_BSD") - def transpose_reshape_bnsh_to_bsd(self, q: str, output_name_to_node) -> Optional[str]: + def transpose_reshape_bnsh_to_bsd(self, q: str, output_name_to_node) -> str | None: transpose_q = helper.make_node( "Transpose", [q], diff --git a/onnxruntime/python/tools/transformers/fusion_nhwc_conv.py b/onnxruntime/python/tools/transformers/fusion_nhwc_conv.py index 5233fdf272fbd..0ad50a270caf7 100644 --- a/onnxruntime/python/tools/transformers/fusion_nhwc_conv.py +++ b/onnxruntime/python/tools/transformers/fusion_nhwc_conv.py @@ -4,7 +4,6 @@ # -------------------------------------------------------------------------- from logging import getLogger -from typing import List from fusion_base import Fusion from fusion_utils import FusionUtils @@ -22,7 +21,7 @@ def __init__(self, model: OnnxModel, update_weight=False): self.update_weight = update_weight self.fusion_utils = FusionUtils(model) - def create_transpose_node(self, input_name: str, perm: List[int], output_name=None): + def create_transpose_node(self, input_name: str, perm: list[int], output_name=None): """Append a Transpose node after an input""" node_name = self.model.create_node_name("Transpose") diff --git a/onnxruntime/python/tools/transformers/fusion_qordered_attention.py b/onnxruntime/python/tools/transformers/fusion_qordered_attention.py index fb020298bc210..52ccfc6fe368d 100644 --- a/onnxruntime/python/tools/transformers/fusion_qordered_attention.py +++ b/onnxruntime/python/tools/transformers/fusion_qordered_attention.py @@ -4,7 +4,6 @@ # -------------------------------------------------------------------------- from logging import getLogger -from typing import Tuple import numpy as np from fusion_attention import AttentionMask @@ -30,7 +29,7 @@ def __init__( super().__init__(model, "QOrderedAttention", "QOrderedLayerNormalization") - def get_num_heads_and_hidden_size(self, reshape_q: NodeProto) -> Tuple[int, int]: + def get_num_heads_and_hidden_size(self, reshape_q: NodeProto) -> tuple[int, int]: """Detect num_heads and hidden_size from a reshape node. Args: reshape_q (NodeProto): reshape node for Q diff --git a/onnxruntime/python/tools/transformers/fusion_qordered_gelu.py b/onnxruntime/python/tools/transformers/fusion_qordered_gelu.py index 5f395b364eb6f..6a6b52a988c00 100644 --- a/onnxruntime/python/tools/transformers/fusion_qordered_gelu.py +++ b/onnxruntime/python/tools/transformers/fusion_qordered_gelu.py @@ -4,7 +4,6 @@ # -------------------------------------------------------------------------- from logging import getLogger -from typing import Dict from fusion_base import Fusion from fusion_utils import FusionUtils @@ -18,7 +17,7 @@ class FusionQOrderedGelu(Fusion): def __init__(self, model: OnnxModel): super().__init__(model, "QOrderedGelu", ["Gelu", "FastGelu"]) - def fuse(self, node, input_name_to_nodes: Dict, output_name_to_node: Dict): + def fuse(self, node, input_name_to_nodes: dict, output_name_to_node: dict): """ INPUT PATTERN Fuse (quantized) Gelu subgraph into one node QOrderedGelu: diff --git a/onnxruntime/python/tools/transformers/fusion_qordered_layernorm.py b/onnxruntime/python/tools/transformers/fusion_qordered_layernorm.py index 5ec6dadc1e677..c8b1be71d4616 100644 --- a/onnxruntime/python/tools/transformers/fusion_qordered_layernorm.py +++ b/onnxruntime/python/tools/transformers/fusion_qordered_layernorm.py @@ -3,7 +3,6 @@ # Licensed under the MIT License. # -------------------------------------------------------------------------- from logging import getLogger -from typing import Dict from fusion_base import Fusion from fusion_utils import FusionUtils @@ -17,7 +16,7 @@ class FusionQOrderedLayerNormalization(Fusion): def __init__(self, model: OnnxModel): super().__init__(model, "QOrderedLayerNormalization", "LayerNormalization") - def fuse(self, node, input_name_to_nodes: Dict, output_name_to_node: Dict): + def fuse(self, node, input_name_to_nodes: dict, output_name_to_node: dict): """ Fuse (quantized) Layer Normalization subgraph into one node QOrderedLayerNormalization: quantized input -> DQ diff --git a/onnxruntime/python/tools/transformers/fusion_qordered_matmul.py b/onnxruntime/python/tools/transformers/fusion_qordered_matmul.py index 681160479faef..3a373f3fd4d78 100644 --- a/onnxruntime/python/tools/transformers/fusion_qordered_matmul.py +++ b/onnxruntime/python/tools/transformers/fusion_qordered_matmul.py @@ -4,7 +4,6 @@ # -------------------------------------------------------------------------- from logging import getLogger -from typing import Dict from fusion_base import Fusion from fusion_utils import FusionUtils @@ -18,7 +17,7 @@ class FusionQOrderedMatMul(Fusion): def __init__(self, model: OnnxModel): super().__init__(model, "QOrderedMatMul", "MatMul") - def fuse(self, node, input_name_to_nodes: Dict, output_name_to_node: Dict): + def fuse(self, node, input_name_to_nodes: dict, output_name_to_node: dict): matmul_children = self.model.get_children(node, input_name_to_nodes) # Should only have 1 child - Bias Add diff --git a/onnxruntime/python/tools/transformers/fusion_rotary_attention.py b/onnxruntime/python/tools/transformers/fusion_rotary_attention.py index efdcbcfb3dcdc..6657fde2257e5 100644 --- a/onnxruntime/python/tools/transformers/fusion_rotary_attention.py +++ b/onnxruntime/python/tools/transformers/fusion_rotary_attention.py @@ -3,7 +3,6 @@ # Licensed under the MIT License. # -------------------------------------------------------------------------- import logging -from typing import Optional, Union from fusion_attention import FusionAttention from fusion_base import Fusion @@ -51,8 +50,8 @@ def create_mha_node( past_v: str = "", present_k: str = "", present_v: str = "", - scale: Optional[float] = None, - ) -> Union[NodeProto, None]: + scale: float | None = None, + ) -> NodeProto | None: assert self.num_heads > 0 if self.hidden_size > 0 and (self.hidden_size % self.num_heads) != 0: @@ -1131,7 +1130,7 @@ def reassign_extra_outputs(self, rot_emb_node: NodeProto, function: FunctionProt extra_initializers.append(constant_tensorproto.name) # Update references of Constant node outputs to initializer references - for extra_output, extra_initializer in zip(extra_outputs, extra_initializers): + for extra_output, extra_initializer in zip(extra_outputs, extra_initializers, strict=False): nodes_to_update = list(filter(lambda entry: extra_output in entry.input, self.model.model.graph.node)) for node_to_update in nodes_to_update: OnnxModel.replace_node_input(node_to_update, extra_output, extra_initializer) diff --git a/onnxruntime/python/tools/transformers/fusion_shape.py b/onnxruntime/python/tools/transformers/fusion_shape.py index dfa77fc7d0221..18a8fda6a67b1 100644 --- a/onnxruntime/python/tools/transformers/fusion_shape.py +++ b/onnxruntime/python/tools/transformers/fusion_shape.py @@ -4,7 +4,6 @@ # -------------------------------------------------------------------------- from logging import getLogger -from typing import Dict, List, Union from fusion_base import Fusion from fusion_utils import FusionUtils @@ -22,13 +21,13 @@ def __init__(self, model: OnnxModel): self.shape_infer = None self.shape_infer_done = False - def get_dimensions_from_tensor_proto(self, tensor_proto: TensorProto) -> Union[int, None]: + def get_dimensions_from_tensor_proto(self, tensor_proto: TensorProto) -> int | None: if tensor_proto.type.tensor_type.HasField("shape"): return len(tensor_proto.type.tensor_type.shape.dim) else: return None - def get_dimensions(self, input_name: str) -> Union[int, None]: + def get_dimensions(self, input_name: str) -> int | None: shape = self.model.get_shape(input_name) if shape is not None: return len(shape) @@ -45,8 +44,8 @@ def get_dimensions(self, input_name: str) -> Union[int, None]: def fuse( self, concat_node: NodeProto, - input_name_to_nodes: Dict[str, List[NodeProto]], - output_name_to_node: Dict[str, NodeProto], + input_name_to_nodes: dict[str, list[NodeProto]], + output_name_to_node: dict[str, NodeProto], ): # # Simplify subgraph like diff --git a/onnxruntime/python/tools/transformers/fusion_simplified_layernorm.py b/onnxruntime/python/tools/transformers/fusion_simplified_layernorm.py index ca7ff6462b9ff..a0eff081675fe 100644 --- a/onnxruntime/python/tools/transformers/fusion_simplified_layernorm.py +++ b/onnxruntime/python/tools/transformers/fusion_simplified_layernorm.py @@ -1,5 +1,4 @@ import logging -from typing import Dict from fusion_base import Fusion from fusion_skiplayernorm import FusionSkipLayerNormalization @@ -13,7 +12,7 @@ class FusionSimplifiedLayerNormalization(Fusion): def __init__(self, model: OnnxModel): super().__init__(model, "SimplifiedLayerNormalization", "Mul") - def fuse(self, node, input_name_to_nodes: Dict, output_name_to_node: Dict): + def fuse(self, node, input_name_to_nodes: dict, output_name_to_node: dict): if node.op_type != "Mul": return diff --git a/onnxruntime/python/tools/transformers/fusion_skip_group_norm.py b/onnxruntime/python/tools/transformers/fusion_skip_group_norm.py index 676052f747967..b2b3af38253c2 100644 --- a/onnxruntime/python/tools/transformers/fusion_skip_group_norm.py +++ b/onnxruntime/python/tools/transformers/fusion_skip_group_norm.py @@ -3,7 +3,6 @@ # Licensed under the MIT License. # -------------------------------------------------------------------------- from logging import getLogger -from typing import List from fusion_base import Fusion from fusion_utils import NumpyHelper @@ -26,7 +25,7 @@ def __init__(self, model: OnnxModel): if self.shape_infer_helper is None: logger.warning("SkipGroupNorm fusion will be skipped since symbolic shape inference disabled or failed.") - def create_transpose_node(self, input_name: str, perm: List[int], output_name=None): + def create_transpose_node(self, input_name: str, perm: list[int], output_name=None): """Append a Transpose node after an input""" node_name = self.model.create_node_name("Transpose") if output_name is None: diff --git a/onnxruntime/python/tools/transformers/fusion_transpose.py b/onnxruntime/python/tools/transformers/fusion_transpose.py index ca699903a7cd9..d38fcffb2af0d 100644 --- a/onnxruntime/python/tools/transformers/fusion_transpose.py +++ b/onnxruntime/python/tools/transformers/fusion_transpose.py @@ -4,7 +4,6 @@ # -------------------------------------------------------------------------- from logging import getLogger -from typing import Dict, List from fusion_base import Fusion from fusion_utils import FusionUtils @@ -21,8 +20,8 @@ def __init__(self, model: OnnxModel): def fuse( self, transpose_node: NodeProto, - input_name_to_nodes: Dict[str, List[NodeProto]], - output_name_to_node: Dict[str, NodeProto], + input_name_to_nodes: dict[str, list[NodeProto]], + output_name_to_node: dict[str, NodeProto], ): """ Note that onnxruntime will do comprehensive transpose optimization after loading model. @@ -90,7 +89,7 @@ class FusionInsertTranspose(Fusion): def __init__(self, model: OnnxModel): super().__init__(model, "", "GroupNorm") - def create_transpose_node(self, input_name: str, perm: List[int], output_name=None): + def create_transpose_node(self, input_name: str, perm: list[int], output_name=None): """Append a Transpose node after an input""" node_name = self.model.create_node_name("Transpose") if output_name is None: @@ -102,8 +101,8 @@ def create_transpose_node(self, input_name: str, perm: List[int], output_name=No def fuse( self, group_norm_node: NodeProto, - input_name_to_nodes: Dict[str, List[NodeProto]], - output_name_to_node: Dict[str, NodeProto], + input_name_to_nodes: dict[str, list[NodeProto]], + output_name_to_node: dict[str, NodeProto], ): """ This optimization will insert an Transpose, and onnxruntime transpose optimizer will remove it together with diff --git a/onnxruntime/python/tools/transformers/fusion_utils.py b/onnxruntime/python/tools/transformers/fusion_utils.py index 3084b84278994..5343c77adb97a 100644 --- a/onnxruntime/python/tools/transformers/fusion_utils.py +++ b/onnxruntime/python/tools/transformers/fusion_utils.py @@ -3,7 +3,6 @@ # Licensed under the MIT License. # -------------------------------------------------------------------------- from logging import getLogger -from typing import Optional, Tuple import numpy from numpy import array_equal, ndarray @@ -18,7 +17,7 @@ class FusionUtils: def __init__(self, model: OnnxModel): self.model: OnnxModel = model - def cast_graph_input_to_int32(self, input_name: str) -> Tuple[bool, str]: + def cast_graph_input_to_int32(self, input_name: str) -> tuple[bool, str]: graph_input = self.model.find_graph_input(input_name) if graph_input is not None and graph_input.type.tensor_type.elem_type != TensorProto.INT32: cast_output, cast_node = self.cast_input_to_int32(input_name) @@ -48,9 +47,9 @@ def add_cast_node( self, input_name: str, to_type: int, - output_name: Optional[str] = None, + output_name: str | None = None, output_name_to_node=None, - graph_name: Optional[str] = None, + graph_name: str | None = None, ): if output_name is None: output_name = input_name + f"_cast_to_{to_type}" @@ -127,7 +126,7 @@ def skip_parent(model: OnnxModel, node, parent_node, input_name_to_nodes, node_i return parent_can_be_removed - def get_squeeze_or_unsqueeze_axes(self, node: NodeProto) -> Optional[ndarray]: + def get_squeeze_or_unsqueeze_axes(self, node: NodeProto) -> ndarray | None: assert node.op_type in ["Squeeze", "Unsqueeze"] # For opset >= 13, axes is an input instead of an attribute. diff --git a/onnxruntime/python/tools/transformers/io_binding_helper.py b/onnxruntime/python/tools/transformers/io_binding_helper.py index 0fa038d5cfc62..5870a031086ee 100644 --- a/onnxruntime/python/tools/transformers/io_binding_helper.py +++ b/onnxruntime/python/tools/transformers/io_binding_helper.py @@ -1,7 +1,8 @@ import copy import logging from collections import OrderedDict -from typing import Any, Dict, List, Mapping, Optional, Tuple, Union +from collections.abc import Mapping +from typing import Any import numpy import torch @@ -9,7 +10,7 @@ from onnxruntime import InferenceSession, RunOptions # Type alias -ShapeDict = Mapping[str, Union[Tuple, List[int]]] +ShapeDict = Mapping[str, tuple | list[int]] logger = logging.getLogger(__name__) @@ -88,7 +89,7 @@ def torch_type_to_numpy_type(torch_type: torch.dtype): return torch_type_to_numpy_type_map[torch_type] @staticmethod - def get_io_numpy_type_map(ort_session: InferenceSession) -> Dict[str, numpy.dtype]: + def get_io_numpy_type_map(ort_session: InferenceSession) -> dict[str, numpy.dtype]: """Create a mapping from input/output name to numpy data type""" name_to_numpy_type = {} for input in ort_session.get_inputs(): @@ -116,7 +117,7 @@ def prepare_io_binding( input_ids: torch.Tensor, position_ids: torch.Tensor, attention_mask: torch.Tensor, - past: List[torch.Tensor], + past: list[torch.Tensor], output_buffers, output_shapes, name_to_np_type=None, @@ -228,7 +229,7 @@ def __init__(self, ort_session: InferenceSession, device: torch.device, enable_c self.device = device # Pairs of input and output names that share the same buffer. - self.buffer_sharing: Dict[str, str] = {} + self.buffer_sharing: dict[str, str] = {} def set_buffer_sharing(self, input_name: str, output_name: str): assert input_name in self.input_names @@ -307,7 +308,7 @@ def allocate_buffers(self, shape_dict: ShapeDict): tensor.data_ptr(), ) - def infer(self, feed_dict: Dict[str, torch.Tensor], run_options: RunOptions = None, synchronize: bool = True): + def infer(self, feed_dict: dict[str, torch.Tensor], run_options: RunOptions = None, synchronize: bool = True): """Bind input tensors and run inference""" for name, tensor in feed_dict.items(): assert isinstance(tensor, torch.Tensor) and tensor.is_contiguous() @@ -330,7 +331,7 @@ def infer(self, feed_dict: Dict[str, torch.Tensor], run_options: RunOptions = No return self.output_tensors @staticmethod - def get_cuda_provider_options(device_id: int, enable_cuda_graph: bool, stream: int = 0) -> Dict[str, Any]: + def get_cuda_provider_options(device_id: int, enable_cuda_graph: bool, stream: int = 0) -> dict[str, Any]: options = { "device_id": device_id, "arena_extend_strategy": "kSameAsRequested", @@ -353,7 +354,7 @@ def __init__( enable_gpu_graph: bool = False, gpu_graph_id: int = -1, stream: int = 0, - buffer_sharing: Optional[Dict[str, str]] = None, + buffer_sharing: dict[str, str] | None = None, ): super().__init__(ort_session, device, enable_gpu_graph) if buffer_sharing: @@ -379,7 +380,7 @@ def get_run_options(self, disable_cuda_graph_in_run: bool = False) -> RunOptions return options - def infer(self, feed_dict: Dict[str, torch.Tensor], disable_cuda_graph_in_run: bool = False): + def infer(self, feed_dict: dict[str, torch.Tensor], disable_cuda_graph_in_run: bool = False): run_options = self.get_run_options(disable_cuda_graph_in_run) if self.stream: @@ -411,7 +412,7 @@ def get_binding( self, shape_dict: ShapeDict, use_cuda_graph: bool = False, - buffer_sharing: Optional[Dict[str, str]] = None, + buffer_sharing: dict[str, str] | None = None, ) -> GpuBinding: for gpu_graph_binding in self.graph_bindings: # Found a cuda graph that captured with the same shape diff --git a/onnxruntime/python/tools/transformers/large_model_exporter.py b/onnxruntime/python/tools/transformers/large_model_exporter.py index f623102802a67..29829a6c475d9 100644 --- a/onnxruntime/python/tools/transformers/large_model_exporter.py +++ b/onnxruntime/python/tools/transformers/large_model_exporter.py @@ -13,7 +13,6 @@ import os import tempfile from pathlib import Path -from typing import Optional import onnx import torch @@ -50,7 +49,7 @@ def get_model_parameter_size(model: nn.Module): return all_size -def initialize_model_and_sample_inputs(hf_model: str, cache_dir: Optional[str], tokenizer=None): +def initialize_model_and_sample_inputs(hf_model: str, cache_dir: str | None, tokenizer=None): """ get the pretrained torch model from hugginface, and sample model-inputs @@ -155,7 +154,7 @@ def hook_for_inputs(_, inputs, kwargs): for key, value in user_inputs[1].items(): idx = input_keys.index(key) onnx_inputs[idx] = value - for idx, (key, value) in enumerate(zip(input_keys, onnx_inputs)): + for idx, (key, value) in enumerate(zip(input_keys, onnx_inputs, strict=False)): if type(value) is torch.Tensor: value.to(model.device) if "use_cache" in key: @@ -309,7 +308,7 @@ def do_export_internal(model: nn.Module, onnx_io_tuple: tuple, onnx_inputs: tupl @torch.no_grad() -def export_onnx(hf_model: str, cache_dir: Optional[str], onnx_path_str: str, with_past: bool, opset: int): +def export_onnx(hf_model: str, cache_dir: str | None, onnx_path_str: str, with_past: bool, opset: int): """ do export model: torch model diff --git a/onnxruntime/python/tools/transformers/machine_info.py b/onnxruntime/python/tools/transformers/machine_info.py index d4194abbd14d3..7f9a0110bcd9f 100644 --- a/onnxruntime/python/tools/transformers/machine_info.py +++ b/onnxruntime/python/tools/transformers/machine_info.py @@ -10,7 +10,6 @@ import logging import platform from os import environ -from typing import Dict, List import cpuinfo import psutil @@ -66,12 +65,12 @@ def get_machine_info(self): } return machine_info - def get_memory_info(self) -> Dict: + def get_memory_info(self) -> dict: """Get memory info""" mem = psutil.virtual_memory() return {"total": mem.total, "available": mem.available} - def _try_get(self, cpu_info: Dict, names: List) -> str: + def _try_get(self, cpu_info: dict, names: list) -> str: for name in names: if name in cpu_info: value = cpu_info[name] @@ -80,7 +79,7 @@ def _try_get(self, cpu_info: Dict, names: List) -> str: return value return "" - def get_cpu_info(self) -> Dict: + def get_cpu_info(self) -> dict: """Get CPU info""" cpu_info = cpuinfo.get_cpu_info() @@ -94,7 +93,7 @@ def get_cpu_info(self) -> Dict: "processor": platform.uname().processor, } - def get_gpu_info_by_nvml(self) -> Dict: + def get_gpu_info_by_nvml(self) -> dict: """Get GPU info using nvml""" gpu_info_list = [] driver_version = None @@ -122,7 +121,7 @@ def get_gpu_info_by_nvml(self) -> Dict: result["cuda_visible"] = environ["CUDA_VISIBLE_DEVICES"] return result - def get_related_packages(self) -> List[str]: + def get_related_packages(self) -> list[str]: import pkg_resources installed_packages = pkg_resources.working_set @@ -142,7 +141,7 @@ def get_related_packages(self) -> List[str]: related_packages_list = {i.key: i.version for i in installed_packages if i.key in related_packages} return related_packages_list - def get_onnxruntime_info(self) -> Dict: + def get_onnxruntime_info(self) -> dict: try: import onnxruntime @@ -159,7 +158,7 @@ def get_onnxruntime_info(self) -> Dict: self.logger.exception(exception, False) return None - def get_pytorch_info(self) -> Dict: + def get_pytorch_info(self) -> dict: try: import torch @@ -177,7 +176,7 @@ def get_pytorch_info(self) -> Dict: self.logger.exception(exception, False) return None - def get_tensorflow_info(self) -> Dict: + def get_tensorflow_info(self) -> dict: try: import tensorflow as tf diff --git a/onnxruntime/python/tools/transformers/metrics.py b/onnxruntime/python/tools/transformers/metrics.py index 282c75ba8f6a5..74a34df28c019 100644 --- a/onnxruntime/python/tools/transformers/metrics.py +++ b/onnxruntime/python/tools/transformers/metrics.py @@ -6,7 +6,6 @@ import datetime import json -from typing import Optional import pandas as pd @@ -30,10 +29,10 @@ def to_dict(self): class ModelInfo(BaseObject): def __init__( self, - full_name: Optional[str] = None, - is_huggingface: Optional[bool] = False, - is_text_generation: Optional[bool] = False, - short_name: Optional[str] = None, + full_name: str | None = None, + is_huggingface: bool | None = False, + is_text_generation: bool | None = False, + short_name: str | None = None, ): super().__init__() self.full_name = full_name @@ -46,9 +45,9 @@ def __init__( class BackendOptions(BaseObject): def __init__( self, - enable_profiling: Optional[bool] = False, - execution_provider: Optional[str] = None, - use_io_binding: Optional[bool] = False, + enable_profiling: bool | None = False, + execution_provider: str | None = None, + use_io_binding: bool | None = False, ): super().__init__() self.enable_profiling = enable_profiling @@ -59,12 +58,12 @@ def __init__( class Config(BaseObject): def __init__( self, - backend: Optional[str] = "onnxruntime", - batch_size: Optional[int] = 1, - seq_length: Optional[int] = 0, - precision: Optional[str] = "fp32", - warmup_runs: Optional[int] = 1, - measured_runs: Optional[int] = 10, + backend: str | None = "onnxruntime", + batch_size: int | None = 1, + seq_length: int | None = 0, + precision: str | None = "fp32", + warmup_runs: int | None = 1, + measured_runs: int | None = 10, ): super().__init__() self.backend = backend @@ -80,11 +79,11 @@ def __init__( class Metadata(BaseObject): def __init__( self, - device: Optional[str] = None, - package_name: Optional[str] = None, - package_version: Optional[str] = None, - platform: Optional[str] = None, - python_version: Optional[str] = None, + device: str | None = None, + package_name: str | None = None, + package_version: str | None = None, + platform: str | None = None, + python_version: str | None = None, ): super().__init__() self.device = device @@ -97,9 +96,9 @@ def __init__( class Metrics(BaseObject): def __init__( self, - latency_ms_mean: Optional[float] = 0.0, - throughput_qps: Optional[float] = 0.0, - max_memory_usage_GB: Optional[float] = 0.0, + latency_ms_mean: float | None = 0.0, + throughput_qps: float | None = 0.0, + max_memory_usage_GB: float | None = 0.0, ): super().__init__() self.latency_ms_mean = latency_ms_mean @@ -116,10 +115,10 @@ def __init__( device: str, package_name: str, package_version: str, - batch_size: Optional[int] = 1, - warmup_runs: Optional[int] = 1, - measured_runs: Optional[int] = 10, - trigger_date: Optional[str] = None, + batch_size: int | None = 1, + warmup_runs: int | None = 1, + measured_runs: int | None = 10, + trigger_date: str | None = None, ): self.config = Config() self.metrics = Metrics() diff --git a/onnxruntime/python/tools/transformers/models/bart/utils/export_helper.py b/onnxruntime/python/tools/transformers/models/bart/utils/export_helper.py index 8b7c18dbde7d9..85d2fa9a64e23 100644 --- a/onnxruntime/python/tools/transformers/models/bart/utils/export_helper.py +++ b/onnxruntime/python/tools/transformers/models/bart/utils/export_helper.py @@ -4,13 +4,12 @@ # license information. # -------------------------------------------------------------------------- -from typing import List, Tuple import torch from transformers import BartConfig, BartForConditionalGeneration, BartTokenizer -def group_by_self_and_cross(present_key_values: Tuple[torch.Tensor], concat: bool = False): +def group_by_self_and_cross(present_key_values: tuple[torch.Tensor], concat: bool = False): """Categorize present_key_values into self and cross attention. Split present state from grouped by layer to grouped by self/cross attention. @@ -27,8 +26,8 @@ def group_by_self_and_cross(present_key_values: Tuple[torch.Tensor], concat: boo present_self (Tuple[torch.Tensor]): present key and values from self attention present_cross (Tuple[torch.Tensor]): present key and values from cross attention """ - present_self: List[torch.Tensor] = [] - present_cross: List[torch.Tensor] = [] + present_self: list[torch.Tensor] = [] + present_cross: list[torch.Tensor] = [] for _, present_layer_i in enumerate(present_key_values): assert len(present_layer_i) == 4, f"Expected to have four items. Got {len(present_layer_i)}" present_key_self, present_value_self, present_key_cross, present_value_cross = present_layer_i @@ -40,7 +39,7 @@ def group_by_self_and_cross(present_key_values: Tuple[torch.Tensor], concat: boo return present_self, present_cross -def back_group_by_layer(past_key_values: Tuple[Tuple[torch.Tensor]]): +def back_group_by_layer(past_key_values: tuple[tuple[torch.Tensor]]): """Categorize present_key_values from self and cross attention to layer by layer. Reorder past state from grouped by self/cross attention to grouped by layer. @@ -70,7 +69,7 @@ def back_group_by_layer(past_key_values: Tuple[Tuple[torch.Tensor]]): return past_tuples -def get_input_names(past_key_values: Tuple[Tuple[torch.Tensor]], encoder=True): +def get_input_names(past_key_values: tuple[tuple[torch.Tensor]], encoder=True): """Process input names of model wrapper. Args: @@ -89,7 +88,7 @@ def get_input_names(past_key_values: Tuple[Tuple[torch.Tensor]], encoder=True): return names -def get_output_names(past_key_values: Tuple[torch.Tensor]): +def get_output_names(past_key_values: tuple[torch.Tensor]): """Process output names of model wrapper. As cross attention is unchanged during every iteration of beam search, diff --git a/onnxruntime/python/tools/transformers/models/bart/utils/export_summarization_edinit.py b/onnxruntime/python/tools/transformers/models/bart/utils/export_summarization_edinit.py index 8a610fb17671b..f8d13ca041349 100644 --- a/onnxruntime/python/tools/transformers/models/bart/utils/export_summarization_edinit.py +++ b/onnxruntime/python/tools/transformers/models/bart/utils/export_summarization_edinit.py @@ -6,7 +6,7 @@ import os import time -from typing import Any, Dict, Optional +from typing import Any import torch from transformers import BartConfig, BartForConditionalGeneration, file_utils @@ -87,8 +87,8 @@ def _create_encoder_export(args, config: BartConfig): """ def _prepare_encoder_decoder_kwargs_for_generation( - self, input_ids: torch.Tensor, model_kwargs, model_input_name: Optional[str] = None - ) -> Dict[str, Any]: + self, input_ids: torch.Tensor, model_kwargs, model_input_name: str | None = None + ) -> dict[str, Any]: # retrieve encoder hidden states # 1. get encoder encoder = self.get_encoder() diff --git a/onnxruntime/python/tools/transformers/models/bart/utils/export_summarization_enc_dec_past.py b/onnxruntime/python/tools/transformers/models/bart/utils/export_summarization_enc_dec_past.py index afd01ae9d025f..475e4c5aecd18 100644 --- a/onnxruntime/python/tools/transformers/models/bart/utils/export_summarization_enc_dec_past.py +++ b/onnxruntime/python/tools/transformers/models/bart/utils/export_summarization_enc_dec_past.py @@ -208,7 +208,7 @@ def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwarg # Test the generated model with onnxruntime print("========== ORT inference test on Decoder ... ==========") - ort_inputs = {name: value.cpu().numpy() for name, value in zip(input_names, inputs)} + ort_inputs = {name: value.cpu().numpy() for name, value in zip(input_names, inputs, strict=False)} # NOTE: encoder_hidden_states is not used and deleted ort_inputs.pop("encoder_hidden_states") sess_options = SessionOptions() @@ -216,7 +216,7 @@ def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwarg sess = InferenceSession(onnx_model_path, sess_options, providers=["CPUExecutionProvider"]) out = sess.run(None, ort_inputs) - for ort_out, torch_out in zip(out, [logits, *present]): + for ort_out, torch_out in zip(out, [logits, *present], strict=False): torch.testing.assert_close(ort_out, torch_out.cpu().numpy(), check_dtype=True, atol=1e-4, rtol=1e-2) print("========== [SUCCESS] ORT inference test on Decoder ==========") diff --git a/onnxruntime/python/tools/transformers/models/bert/eval_squad.py b/onnxruntime/python/tools/transformers/models/bert/eval_squad.py index 8797fd9c2cfaf..680b3455ade2d 100644 --- a/onnxruntime/python/tools/transformers/models/bert/eval_squad.py +++ b/onnxruntime/python/tools/transformers/models/bert/eval_squad.py @@ -33,7 +33,7 @@ from importlib_metadata import PackageNotFoundError, version from pathlib import Path -from typing import Any, Dict, List, Optional +from typing import Any from datasets import load_dataset from evaluate import evaluator @@ -60,7 +60,7 @@ def get_package_version(package_name: str): def load_onnx_model( - model_id: str, onnx_path: Optional[str] = None, provider="CUDAExecutionProvider", use_io_binding: bool = False + model_id: str, onnx_path: str | None = None, provider="CUDAExecutionProvider", use_io_binding: bool = False ): """Load onnx model given pretrained model name and optional ONNX model path. If onnx_path is None, the default onnx model from optimum will be used. @@ -95,7 +95,7 @@ def load_onnx_model( return model, onnx_path -def output_details(results: List[Dict[str, Any]], csv_filename: str): +def output_details(results: list[dict[str, Any]], csv_filename: str): """Output a CSV file with detail of each test results. Args: @@ -136,7 +136,7 @@ def output_details(results: List[Dict[str, Any]], csv_filename: str): print(f"Detail results are saved to csv file: {csv_filename}") -def output_summary(results: List[Dict[str, Any]], csv_filename: str, metric_name: str): +def output_summary(results: list[dict[str, Any]], csv_filename: str, metric_name: str): """Output a CSV file with summary of a metric on combinations of batch_size and sequence_length. Args: diff --git a/onnxruntime/python/tools/transformers/models/gpt2/gpt2_helper.py b/onnxruntime/python/tools/transformers/models/gpt2/gpt2_helper.py index 1b12fe9005175..b405c19b04689 100644 --- a/onnxruntime/python/tools/transformers/models/gpt2/gpt2_helper.py +++ b/onnxruntime/python/tools/transformers/models/gpt2/gpt2_helper.py @@ -12,7 +12,6 @@ import tempfile import time from pathlib import Path -from typing import Dict, List, Tuple, Union import numpy import onnx @@ -139,17 +138,17 @@ class Gpt2Inputs: def __init__(self, input_ids, position_ids, attention_mask, past): self.input_ids: torch.LongTensor = input_ids self.position_ids: torch.LongTensor = position_ids - self.attention_mask: Union[torch.LongTensor, torch.FloatTensor, torch.HalfTensor] = attention_mask - self.past: Union[List[torch.FloatTensor], List[torch.HalfTensor]] = past + self.attention_mask: torch.LongTensor | torch.FloatTensor | torch.HalfTensor = attention_mask + self.past: list[torch.FloatTensor] | list[torch.HalfTensor] = past - def to_list(self) -> List: + def to_list(self) -> list: input_list = [v for v in [self.input_ids, self.position_ids, self.attention_mask] if v is not None] if self.past: input_list.extend(self.past) return input_list - def to_tuple(self) -> Tuple: + def to_tuple(self) -> tuple: return tuple(v for v in [self.input_ids, self.position_ids, self.attention_mask, self.past] if v is not None) def to_fp32(self): @@ -241,7 +240,7 @@ def get_output_shapes( sequence_length: int, config: GPT2Config, model_class: str = "GPT2LMHeadModel", - ) -> Dict[str, List[int]]: + ) -> dict[str, list[int]]: """Returns a dictionary with output name as key, and shape as value.""" num_attention_heads = config.num_attention_heads hidden_size = config.hidden_size @@ -541,7 +540,7 @@ def optimize_onnx( @staticmethod def auto_mixed_precision( onnx_model: OnnxModel, - op_block_list: List[str] = [ # noqa: B006 + op_block_list: list[str] = [ # noqa: B006 "Add", "LayerNormalization", "SkipLayerNormalization", @@ -698,8 +697,8 @@ def get_outputs_from_io_binding_buffer(ort_session, output_buffers, output_shape def onnxruntime_inference_with_binded_io( ort_session, inputs: Gpt2Inputs, - output_buffers: Dict[str, torch.Tensor], - output_shapes: Dict[str, List[int]], + output_buffers: dict[str, torch.Tensor], + output_shapes: dict[str, list[int]], total_runs: int = 0, return_numpy: bool = True, include_copy_output_latency: bool = False, diff --git a/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py b/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py index 7bf8bcb82e59a..89fd613ecbbc2 100644 --- a/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py +++ b/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py @@ -427,7 +427,7 @@ def convert_to_float16(args: argparse.Namespace, old_paths: list[str], rank: int new_paths = [decoder_model_fp16_path, decoder_with_past_model_fp16_path, decoder_merged_model_fp16_path] logger.info("Converting to float16...") - for fp32_path, fp16_path in zip(old_paths, new_paths): + for fp32_path, fp16_path in zip(old_paths, new_paths, strict=False): if os.path.exists(fp32_path): model = OnnxModel(onnx.load_model(fp32_path, load_external_data=True)) model.convert_float_to_float16(keep_io_types=False) @@ -867,7 +867,7 @@ def main(): # Run the optimizer script. logger.info("Optimizing models...") - for orig_path, opt_path in zip(old_paths, new_paths): + for orig_path, opt_path in zip(old_paths, new_paths, strict=False): if os.path.exists(orig_path): optimize_export(args, l_config, input_path=orig_path, output_path=opt_path, world_size=world_size) @@ -912,7 +912,7 @@ def main(): ) logger.info("Quantizing to int8...") - for fp32_path, int8_path in zip(old_paths, new_paths): + for fp32_path, int8_path in zip(old_paths, new_paths, strict=False): if os.path.exists(fp32_path): ort_quantization.quantize_dynamic( fp32_path, @@ -952,7 +952,7 @@ def main(): ) new_paths = [decoder_model_int4_path, decoder_with_past_model_int4_path, decoder_merged_model_int4_path] - for fp_path, int4_path in zip(old_paths, new_paths): + for fp_path, int4_path in zip(old_paths, new_paths, strict=False): if os.path.exists(fp_path): model = onnx.load_model(fp_path, load_external_data=True) quant = MatMul4BitsQuantizer( diff --git a/onnxruntime/python/tools/transformers/models/longformer/benchmark_longformer.py b/onnxruntime/python/tools/transformers/models/longformer/benchmark_longformer.py index 274d56df3f12c..c7e0e31765a4f 100644 --- a/onnxruntime/python/tools/transformers/models/longformer/benchmark_longformer.py +++ b/onnxruntime/python/tools/transformers/models/longformer/benchmark_longformer.py @@ -41,7 +41,7 @@ import traceback from concurrent.futures import ProcessPoolExecutor from datetime import datetime -from typing import Any, Dict, List +from typing import Any import benchmark_helper import numpy as np @@ -63,7 +63,7 @@ def test_torch_latency( global_lengths, test_times, num_threads, -) -> List[Dict[str, Any]]: +) -> list[dict[str, Any]]: if num_threads > 0: torch.set_num_threads(num_threads) @@ -143,7 +143,7 @@ def test_ort_latency( use_compact_memory=False, use_half4=False, disable_parity=False, -) -> List[Dict[str, Any]]: +) -> list[dict[str, Any]]: results = [] for batch_size in batch_sizes: for sequence_length in sequence_lengths: @@ -250,7 +250,7 @@ def test_ort_memory( global_length, test_times, num_threads, -) -> Dict[str, Any]: +) -> dict[str, Any]: logger.info( f"Testing memory for model={onnx_model_path}, batch_size={batch_size}, sequence_length={sequence_length}, " f"global_length={global_length}, test_times={test_times}, num_threads={num_threads}" @@ -307,7 +307,7 @@ def find_onnx_model(model_name, onnx_dir="."): return onnx_model_path -def test_memory(args, device) -> Dict[str, Any]: +def test_memory(args, device) -> dict[str, Any]: if len(args.batch_sizes) > 1: raise RuntimeError("For memory test, only one batch_size (-b) is allowed.") if len(args.sequence_lengths) > 1: @@ -330,7 +330,7 @@ def test_memory(args, device) -> Dict[str, Any]: ) -def test_ort(args, device) -> List[Dict[str, Any]]: +def test_ort(args, device) -> list[dict[str, Any]]: model_name = args.model onnx_model_path = find_onnx_model(model_name) if not args.onnx else args.onnx @@ -385,7 +385,7 @@ def test_ort(args, device) -> List[Dict[str, Any]]: ) -def test_torch(args, device) -> List[Dict[str, Any]]: +def test_torch(args, device) -> list[dict[str, Any]]: model = load_torch_model(args.model, device) return test_torch_latency( device, @@ -399,7 +399,7 @@ def test_torch(args, device) -> List[Dict[str, Any]]: ) -def test_latency(args, device) -> List[Dict[str, Any]]: +def test_latency(args, device) -> list[dict[str, Any]]: if args.engine == "onnxruntime": return test_ort(args, device) @@ -550,7 +550,7 @@ def output_details(results, csv_filename): print(f"Detail results are saved to csv file: {csv_filename}") -def run(args) -> List[Dict[str, Any]]: +def run(args) -> list[dict[str, Any]]: torch.set_grad_enabled(False) # set random seed manually to get deterministic results @@ -565,7 +565,7 @@ def run(args) -> List[Dict[str, Any]]: return test_latency(args, device) -def launch_test(arguments) -> List[Dict[str, Any]]: +def launch_test(arguments) -> list[dict[str, Any]]: if not torch.cuda.is_available(): raise RuntimeError("Please install PyTorch with Cuda, and use a machine with GPU for testing gpu performance.") diff --git a/onnxruntime/python/tools/transformers/models/longformer/longformer_helper.py b/onnxruntime/python/tools/transformers/models/longformer/longformer_helper.py index 1794bf75b4e6f..08a2ba629fbc3 100644 --- a/onnxruntime/python/tools/transformers/models/longformer/longformer_helper.py +++ b/onnxruntime/python/tools/transformers/models/longformer/longformer_helper.py @@ -6,7 +6,6 @@ # This script helps creating dummy inputs for Longformer model. import logging -from typing import Dict, List, Tuple, Union import numpy import torch @@ -23,16 +22,16 @@ class LongformerInputs: def __init__(self, input_ids, attention_mask, global_attention_mask): self.input_ids: torch.LongTensor = input_ids - self.attention_mask: Union[torch.FloatTensor, torch.HalfTensor] = attention_mask - self.global_attention_mask: Union[torch.FloatTensor, torch.HalfTensor] = global_attention_mask + self.attention_mask: torch.FloatTensor | torch.HalfTensor = attention_mask + self.global_attention_mask: torch.FloatTensor | torch.HalfTensor = global_attention_mask - def to_list(self) -> List: + def to_list(self) -> list: return [v for v in [self.input_ids, self.attention_mask, self.global_attention_mask] if v is not None] - def to_tuple(self) -> Tuple: + def to_tuple(self) -> tuple: return tuple(v for v in self.to_list()) - def get_ort_inputs(self) -> Dict: + def get_ort_inputs(self) -> dict: return { "input_ids": numpy.ascontiguousarray(self.input_ids.cpu().numpy()), "attention_mask": numpy.ascontiguousarray(self.attention_mask.cpu().numpy()), @@ -69,7 +68,7 @@ def get_dummy_inputs( return LongformerInputs(input_ids, attention_mask, global_attention_mask) @staticmethod - def get_output_shapes(batch_size: int, sequence_length: int, hidden_size: int) -> Dict[str, List[int]]: + def get_output_shapes(batch_size: int, sequence_length: int, hidden_size: int) -> dict[str, list[int]]: """Returns a dictionary with output name as key, and shape as value.""" return { "last_state": [batch_size, sequence_length, hidden_size], diff --git a/onnxruntime/python/tools/transformers/models/sam2/benchmark_sam2.py b/onnxruntime/python/tools/transformers/models/sam2/benchmark_sam2.py index f75a4527be57d..16d71d5057b02 100644 --- a/onnxruntime/python/tools/transformers/models/sam2/benchmark_sam2.py +++ b/onnxruntime/python/tools/transformers/models/sam2/benchmark_sam2.py @@ -11,8 +11,8 @@ import csv import statistics import time +from collections.abc import Mapping from datetime import datetime -from typing import List, Mapping, Optional import torch from image_decoder import SAM2ImageDecoder @@ -84,7 +84,7 @@ def __init__( def __repr__(self): return f"{vars(self)}" - def shape_dict(self) -> Mapping[str, List[int]]: + def shape_dict(self) -> Mapping[str, list[int]]: if self.component == "image_encoder": return encoder_shape_dict(self.batch_size, self.height, self.width) else: @@ -283,7 +283,7 @@ def run_torch(config: TestConfig): def run_test( args: argparse.Namespace, - csv_writer: Optional[csv.DictWriter] = None, + csv_writer: csv.DictWriter | None = None, ): use_gpu: bool = args.use_gpu enable_cuda_graph: bool = args.use_cuda_graph diff --git a/onnxruntime/python/tools/transformers/models/sam2/image_encoder.py b/onnxruntime/python/tools/transformers/models/sam2/image_encoder.py index b9f30d0371dbe..c5ce339732063 100644 --- a/onnxruntime/python/tools/transformers/models/sam2/image_encoder.py +++ b/onnxruntime/python/tools/transformers/models/sam2/image_encoder.py @@ -75,7 +75,7 @@ def forward( feats = [ feat.permute(1, 2, 0).reshape(1, -1, *feat_size) - for feat, feat_size in zip(vision_feats[::-1], feat_sizes[::-1]) + for feat, feat_size in zip(vision_feats[::-1], feat_sizes[::-1], strict=False) ][::-1] if nvtx_helper is not None: diff --git a/onnxruntime/python/tools/transformers/models/sam2/sam2_demo.py b/onnxruntime/python/tools/transformers/models/sam2/sam2_demo.py index af6b0e17e77f1..7f43724a6343f 100644 --- a/onnxruntime/python/tools/transformers/models/sam2/sam2_demo.py +++ b/onnxruntime/python/tools/transformers/models/sam2/sam2_demo.py @@ -3,7 +3,6 @@ # Licensed under the MIT License. # -------------------------------------------------------------------------- import os -from typing import Union import matplotlib.image as mpimg import matplotlib.pyplot as plt @@ -64,7 +63,7 @@ def show_masks( output_image_file_prefix=None, image_files=None, ): - for i, (mask, score) in enumerate(zip(masks, scores)): + for i, (mask, score) in enumerate(zip(masks, scores, strict=False)): plt.figure(figsize=(10, 10)) plt.imshow(image) show_mask(mask, plt.gca(), borders=borders) @@ -92,7 +91,7 @@ def show_masks( def get_predictor( sam2_dir: str, - device: Union[str, torch.device], + device: str | torch.device, dtype: torch.dtype, model_type="sam2_hiera_large", engine="torch", @@ -303,7 +302,7 @@ def run_demo( def show_all_images(left_images, right_images, suffix=""): # Show images in two rows since display screen is horizontal in most cases. fig, axes = plt.subplots(nrows=2, ncols=len(left_images), figsize=(19.20, 10.80)) - for i, (left_img_path, right_img_path) in enumerate(zip(left_images, right_images)): + for i, (left_img_path, right_img_path) in enumerate(zip(left_images, right_images, strict=False)): left_img = mpimg.imread(left_img_path) right_img = mpimg.imread(right_img_path) diff --git a/onnxruntime/python/tools/transformers/models/sam2/sam2_image_onnx_predictor.py b/onnxruntime/python/tools/transformers/models/sam2/sam2_image_onnx_predictor.py index 3c0c886b877f0..2f34bfa9aa09a 100644 --- a/onnxruntime/python/tools/transformers/models/sam2/sam2_image_onnx_predictor.py +++ b/onnxruntime/python/tools/transformers/models/sam2/sam2_image_onnx_predictor.py @@ -4,7 +4,6 @@ # -------------------------------------------------------------------------- import logging -from typing import Optional, Tuple, Union import numpy as np import torch @@ -41,7 +40,7 @@ def create_session( onnx_path: str, session_options=None, provider="CUDAExecutionProvider", - device: Union[str, torch.device] = "cuda", + device: str | torch.device = "cuda", enable_cuda_graph=False, ) -> CudaSession: ort_session = create_ort_session( @@ -59,7 +58,7 @@ def __init__( image_decoder_onnx_path: str = "", image_decoder_multi_onnx_path: str = "", provider: str = "CUDAExecutionProvider", - device: Union[str, torch.device] = "cuda", + device: str | torch.device = "cuda", onnx_dtype: torch.dtype = torch.float32, mask_threshold=0.0, max_hole_area=0.0, @@ -114,7 +113,7 @@ def __init__( ) @torch.no_grad() - def set_image(self, image: Union[np.ndarray, Image]): + def set_image(self, image: np.ndarray | Image): """ Calculates the image embeddings for the provided image. @@ -162,14 +161,14 @@ def set_image(self, image: Union[np.ndarray, Image]): @torch.no_grad() def _predict( self, - point_coords: Optional[torch.Tensor], - point_labels: Optional[torch.Tensor], - boxes: Optional[torch.Tensor] = None, - mask_input: Optional[torch.Tensor] = None, + point_coords: torch.Tensor | None, + point_labels: torch.Tensor | None, + boxes: torch.Tensor | None = None, + mask_input: torch.Tensor | None = None, multimask_output: bool = True, return_logits: bool = False, img_idx: int = -1, - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Predict masks for the given input prompts, using the currently set image. Input prompts are batched torch tensors and are expected to already be diff --git a/onnxruntime/python/tools/transformers/models/sam2/sam2_utils.py b/onnxruntime/python/tools/transformers/models/sam2/sam2_utils.py index 4ec4ccc274291..d983cefaaaeec 100644 --- a/onnxruntime/python/tools/transformers/models/sam2/sam2_utils.py +++ b/onnxruntime/python/tools/transformers/models/sam2/sam2_utils.py @@ -5,7 +5,7 @@ import logging import os import sys -from typing import List, Mapping, Union +from collections.abc import Mapping import torch from sam2.build_sam import build_sam2 @@ -27,7 +27,7 @@ def _get_model_cfg(model_type) -> str: return model_cfg -def load_sam2_model(sam2_dir, model_type, device: Union[str, torch.device] = "cpu") -> SAM2Base: +def load_sam2_model(sam2_dir, model_type, device: str | torch.device = "cpu") -> SAM2Base: checkpoints_dir = os.path.join(sam2_dir, "checkpoints") sam2_config_dir = os.path.join(sam2_dir, "sam2_configs") if not os.path.exists(sam2_dir): @@ -65,7 +65,7 @@ def sam2_onnx_path(output_dir, model_type, component, multimask_output=False, su ) -def encoder_shape_dict(batch_size: int, height: int, width: int) -> Mapping[str, List[int]]: +def encoder_shape_dict(batch_size: int, height: int, width: int) -> Mapping[str, list[int]]: assert height == 1024 and width == 1024, "Only 1024x1024 images are supported." return { "image": [batch_size, 3, height, width], diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_utils.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_utils.py index a50940933eb82..30f4663100d8a 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_utils.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_utils.py @@ -23,7 +23,7 @@ import os import sys from importlib.metadata import PackageNotFoundError, version -from typing import Any, Dict, List, Optional +from typing import Any import controlnet_aux import cv2 @@ -307,7 +307,7 @@ def max_batch(args): return max_batch_size -def get_metadata(args, is_xl: bool = False) -> Dict[str, Any]: +def get_metadata(args, is_xl: bool = False) -> dict[str, Any]: metadata = { "command": " ".join(['"' + x + '"' if " " in x else x for x in sys.argv]), "args.prompt": args.prompt, @@ -410,7 +410,7 @@ def initialize_pipeline( lora_scale: float = 1.0, use_fp16_vae: bool = True, use_vae: bool = True, - framework_model_dir: Optional[str] = None, + framework_model_dir: str | None = None, max_cuda_graphs: int = 1, ): pipeline_info = PipelineInfo( @@ -649,7 +649,7 @@ def get_canny_image(image) -> Image.Image: return image -def process_controlnet_images_xl(args) -> List[Image.Image]: +def process_controlnet_images_xl(args) -> list[Image.Image]: """ Process control image for SDXL control net. """ diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/diffusion_models.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/diffusion_models.py index c2cfc165e32cf..8dcda8a7633ac 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/diffusion_models.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/diffusion_models.py @@ -24,7 +24,6 @@ import logging import os import tempfile -from typing import Dict, List, Optional import onnx import onnx_graphsurgeon as gs @@ -135,7 +134,7 @@ def is_xl_refiner(self) -> bool: def use_safetensors(self) -> bool: return self.is_xl() or self.version in ["sd-turbo"] - def stages(self) -> List[str]: + def stages(self) -> list[str]: if self.is_xl_base_or_turbo(): return ["clip", "clip2", "unetxl"] + (["vae"] if self._use_vae else []) @@ -150,11 +149,11 @@ def vae_scaling_factor(self) -> float: def vae_torch_fallback(self) -> bool: return self.is_xl() and not self._use_fp16_vae - def custom_fp16_vae(self) -> Optional[str]: + def custom_fp16_vae(self) -> str | None: # For SD XL, use a VAE that fine-tuned to run in fp16 precision without generating NaNs return "madebyollin/sdxl-vae-fp16-fix" if self._use_fp16_vae and self.is_xl() else None - def custom_unet(self) -> Optional[str]: + def custom_unet(self) -> str | None: return "latent-consistency/lcm-sdxl" if self._use_lcm and self.is_xl_base() else None @staticmethod @@ -372,13 +371,13 @@ def from_pretrained(self, model_class, framework_model_dir, subfolder=None, mode def load_model(self, framework_model_dir: str, subfolder: str): pass - def get_input_names(self) -> List[str]: + def get_input_names(self) -> list[str]: pass - def get_output_names(self) -> List[str]: + def get_output_names(self) -> list[str]: pass - def get_dynamic_axes(self) -> Dict[str, Dict[int, str]]: + def get_dynamic_axes(self) -> dict[str, dict[int, str]]: pass def get_sample_input(self, batch_size, image_height, image_width) -> tuple: @@ -418,7 +417,7 @@ def get_input_profile(self, batch_size, image_height, image_width, static_batch, def get_shape_dict(self, batch_size, image_height, image_width): pass - def fp32_input_output_names(self) -> List[str]: + def fp32_input_output_names(self) -> list[str]: """For CUDA EP, we export ONNX model with FP32 first, then convert it to mixed precision model. This is a list of input or output names that are kept as float32 in optimized model. """ @@ -720,7 +719,7 @@ def __init__(self, unet, controlnets: ControlNetModel): def forward(self, sample, timestep, encoder_hidden_states, controlnet_images, controlnet_scales): for i, (controlnet_image, conditioning_scale, controlnet) in enumerate( - zip(controlnet_images, controlnet_scales, self.controlnets) + zip(controlnet_images, controlnet_scales, self.controlnets, strict=False) ): down_samples, mid_sample = controlnet( sample, @@ -739,7 +738,7 @@ def forward(self, sample, timestep, encoder_hidden_states, controlnet_images, co else: down_block_res_samples = [ samples_prev + samples_curr - for samples_prev, samples_curr in zip(down_block_res_samples, down_samples) + for samples_prev, samples_curr in zip(down_block_res_samples, down_samples, strict=False) ] mid_block_res_sample += mid_sample @@ -772,7 +771,7 @@ def forward( ): added_cond_kwargs = {"text_embeds": text_embeds, "time_ids": time_ids} for i, (controlnet_image, conditioning_scale, controlnet) in enumerate( - zip(controlnet_images, controlnet_scales, self.controlnets) + zip(controlnet_images, controlnet_scales, self.controlnets, strict=False) ): down_samples, mid_sample = controlnet( sample, @@ -790,7 +789,7 @@ def forward( else: down_block_res_samples = [ samples_prev + samples_curr - for samples_prev, samples_curr in zip(down_block_res_samples, down_samples) + for samples_prev, samples_curr in zip(down_block_res_samples, down_samples, strict=False) ] mid_block_res_sample += mid_sample @@ -1152,7 +1151,7 @@ def __init__( device, max_batch_size, fp16: bool = False, - custom_fp16_vae: Optional[str] = None, + custom_fp16_vae: str | None = None, ): super().__init__( pipeline_info, @@ -1232,7 +1231,7 @@ def get_sample_input(self, batch_size, image_height, image_width): dtype = torch.float16 if self.fp16 else torch.float32 return (torch.randn(batch_size, 4, latent_height, latent_width, dtype=dtype, device=self.device),) - def fp32_input_output_names(self) -> List[str]: + def fp32_input_output_names(self) -> list[str]: return [] diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/diffusion_schedulers.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/diffusion_schedulers.py index 41d2d267c5e11..ff23874000019 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/diffusion_schedulers.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/diffusion_schedulers.py @@ -21,7 +21,6 @@ # limitations under the License. # -------------------------------------------------------------------------- -from typing import List, Optional import numpy as np import torch @@ -391,8 +390,8 @@ def __init__( predict_x0: bool = True, solver_type: str = "bh2", lower_order_final: bool = True, - disable_corrector: Optional[List[int]] = None, - use_karras_sigmas: Optional[bool] = False, + disable_corrector: list[int] | None = None, + use_karras_sigmas: bool | None = False, timestep_spacing: str = "linspace", steps_offset: int = 0, sigma_min=None, @@ -627,7 +626,7 @@ def multistep_uni_p_bh_update( model_output: torch.FloatTensor, *args, sample: torch.FloatTensor = None, - order: Optional[int] = None, + order: int | None = None, **kwargs, ) -> torch.FloatTensor: prev_timestep = args[0] if len(args) > 0 else kwargs.pop("prev_timestep", None) @@ -734,7 +733,7 @@ def multistep_uni_c_bh_update( *args, last_sample: torch.FloatTensor = None, this_sample: torch.FloatTensor = None, - order: Optional[int] = None, + order: int | None = None, **kwargs, ) -> torch.FloatTensor: this_timestep = args[0] if len(args) > 0 else kwargs.pop("this_timestep", None) @@ -1084,7 +1083,7 @@ def step( model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor, - generator: Optional[torch.Generator] = None, + generator: torch.Generator | None = None, ): if self.num_inference_steps is None: raise ValueError( diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder.py index 7609ae10fc96d..d36411a1fa84d 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder.py @@ -5,7 +5,6 @@ import hashlib import os from enum import Enum -from typing import Optional import torch from diffusion_models import CLIP, VAE, CLIPWithProj, PipelineInfo, UNet, UNetXL @@ -275,7 +274,7 @@ def vae_decode(self, latents): def get_engine_paths( - work_dir: str, pipeline_info: PipelineInfo, engine_type: EngineType, framework_model_dir: Optional[str] = None + work_dir: str, pipeline_info: PipelineInfo, engine_type: EngineType, framework_model_dir: str | None = None ): root_dir = work_dir or "." short_name = pipeline_info.short_name() diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder_ort_cuda.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder_ort_cuda.py index 56012e223b18c..040e3a38dbc52 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder_ort_cuda.py @@ -6,7 +6,6 @@ import gc import logging import os -from typing import Dict, List, Optional import onnx import torch @@ -72,7 +71,7 @@ def metadata(self, name: str): data[f"{name}.gpu_graph_id"] = self.current_gpu_binding.last_run_gpu_graph_id return data - def infer(self, feed_dict: Dict[str, torch.Tensor]): + def infer(self, feed_dict: dict[str, torch.Tensor]): return self.current_gpu_binding.infer(feed_dict=feed_dict, disable_cuda_graph_in_run=not self.enable_cuda_graph) def allocate_buffers(self, shape_dict, device): @@ -93,7 +92,7 @@ def __init__( onnx_opset_version: int, use_cuda_graph: bool, fp16: bool = True, - force_fp32_ops: Optional[List[str]] = None, + force_fp32_ops: list[str] | None = None, optimize_by_ort: bool = True, ): self.onnx_opset_version = onnx_opset_version @@ -140,7 +139,7 @@ def _configure( onnx_opset_version: int, use_cuda_graph: bool, fp16: bool = True, - force_fp32_ops: Optional[List[str]] = None, + force_fp32_ops: list[str] | None = None, optimize_by_ort: bool = True, ): self.model_config[model_name] = _ModelConfig( @@ -238,11 +237,11 @@ def build_engines( engine_dir: str, framework_model_dir: str, onnx_dir: str, - tmp_dir: Optional[str] = None, + tmp_dir: str | None = None, onnx_opset_version: int = 17, device_id: int = 0, save_fp32_intermediate_model: bool = False, - import_engine_dir: Optional[str] = None, + import_engine_dir: str | None = None, max_cuda_graphs: int = 1, ): self.torch_device = torch.device("cuda", device_id) diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/optimize_pipeline.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/optimize_pipeline.py index 52d332848357f..24897756b2d7a 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/optimize_pipeline.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/optimize_pipeline.py @@ -21,7 +21,6 @@ import shutil import tempfile from pathlib import Path -from typing import List, Optional import __init__ # noqa: F401. Walk-around to run this script directly import coloredlogs @@ -101,11 +100,11 @@ def _optimize_sd_pipeline( source_dir: Path, target_dir: Path, pipeline_type: str, - model_list: List[str], - use_external_data_format: Optional[bool], + model_list: list[str], + use_external_data_format: bool | None, float16: bool, bfloat16: bool, - force_fp32_ops: List[str], + force_fp32_ops: list[str], enable_runtime_optimization: bool, args, ): @@ -400,7 +399,7 @@ def _optimize_sd_pipeline( return op_counters -def _copy_extra_directory(source_dir: Path, target_dir: Path, model_list: List[str]): +def _copy_extra_directory(source_dir: Path, target_dir: Path, model_list: list[str]): """Copy extra directory that does not have onnx model Args: @@ -448,7 +447,7 @@ def optimize_stable_diffusion_pipeline( input_dir: str, output_dir: str, overwrite: bool, - use_external_data_format: Optional[bool], + use_external_data_format: bool | None, float16: bool, enable_runtime_optimization: bool, args, @@ -480,7 +479,7 @@ def optimize_stable_diffusion_pipeline( ) -def parse_arguments(argv: Optional[List[str]] = None): +def parse_arguments(argv: list[str] | None = None): """Parse arguments Returns: @@ -570,7 +569,7 @@ def parse_arguments(argv: Optional[List[str]] = None): return args -def main(argv: Optional[List[str]] = None): +def main(argv: list[str] | None = None): args = parse_arguments(argv) logger.info("Arguments: %s", str(args)) diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_stable_diffusion.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_stable_diffusion.py index ac955f50141d2..e2f202e32221d 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_stable_diffusion.py @@ -24,7 +24,7 @@ import pathlib import random import time -from typing import Any, Dict, List, Optional +from typing import Any import numpy as np import nvtx @@ -485,7 +485,7 @@ def decode_latent(self, latents): self.stop_profile("vae") return images - def print_summary(self, tic, toc, batch_size, vae_enc=False, pil=False) -> Dict[str, Any]: + def print_summary(self, tic, toc, batch_size, vae_enc=False, pil=False) -> dict[str, Any]: throughput = batch_size / (toc - tic) latency_clip = cudart.cudaEventElapsedTime(self.events["clip-start"], self.events["clip-stop"])[1] latency_unet = cudart.cudaEventElapsedTime(self.events["denoise-start"], self.events["denoise-stop"])[1] @@ -546,7 +546,7 @@ def pt_to_numpy(images: torch.FloatTensor): """ return ((images + 1) / 2).clamp(0, 1).detach().permute(0, 2, 3, 1).float().cpu().numpy() - def metadata(self) -> Dict[str, Any]: + def metadata(self) -> dict[str, Any]: data = { "actual_steps": self.actual_steps, "seed": self.get_current_seed(), @@ -561,7 +561,7 @@ def metadata(self) -> Dict[str, Any]: return data - def save_images(self, images: List, prompt: List[str], negative_prompt: List[str], metadata: Dict[str, Any]): + def save_images(self, images: list, prompt: list[str], negative_prompt: list[str], metadata: dict[str, Any]): session_id = str(random.randint(1000, 9999)) for i, image in enumerate(images): seed = str(self.get_current_seed()) @@ -747,17 +747,17 @@ def _infer( def run( self, - prompt: List[str], - negative_prompt: List[str], + prompt: list[str], + negative_prompt: list[str], image_height: int, image_width: int, denoising_steps: int = 30, guidance: float = 5.0, - seed: Optional[int] = None, - image: Optional[torch.Tensor] = None, + seed: int | None = None, + image: torch.Tensor | None = None, strength: float = 0.3, - controlnet_images: Optional[torch.Tensor] = None, - controlnet_scales: Optional[torch.Tensor] = None, + controlnet_images: torch.Tensor | None = None, + controlnet_scales: torch.Tensor | None = None, show_latency: bool = False, output_type: str = "pil", deterministic: bool = False, diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/test/check_image.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/test/check_image.py index 86477a7e3168b..ab3d3d8f58545 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/test/check_image.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/test/check_image.py @@ -1,6 +1,5 @@ import argparse import os -from typing import Optional import cv2 import open_clip @@ -19,7 +18,7 @@ def arg_parser(): return args -def image_encoder(img: Image.Image, cache_dir: Optional[str] = None): # -> torch.Tensor: +def image_encoder(img: Image.Image, cache_dir: str | None = None): # -> torch.Tensor: device = "cuda" if torch.cuda.is_available() else "cpu" model, _, preprocess = open_clip.create_model_and_transforms( "ViT-B-16-plus-240", pretrained="laion400m_e32", cache_dir=cache_dir @@ -46,7 +45,7 @@ def load_image(image_path: str): # -> Image.Image: return img -def generate_score(image1: str, image2: str, cache_dir: Optional[str] = None): # -> float: +def generate_score(image1: str, image2: str, cache_dir: str | None = None): # -> float: test_img = load_image(image1) data_img = load_image(image2) img1 = image_encoder(test_img, cache_dir) diff --git a/onnxruntime/python/tools/transformers/models/t5/past_helper.py b/onnxruntime/python/tools/transformers/models/t5/past_helper.py index 915b09da79fe6..0f72a89498dad 100644 --- a/onnxruntime/python/tools/transformers/models/t5/past_helper.py +++ b/onnxruntime/python/tools/transformers/models/t5/past_helper.py @@ -5,7 +5,6 @@ # -------------------------------------------------------------------------- import logging -from typing import List, Tuple import torch @@ -71,7 +70,7 @@ def group_by_layer(past, num_layers): ) @staticmethod - def back_group_by_layer(past_key_values: Tuple[Tuple[torch.Tensor]]): + def back_group_by_layer(past_key_values: tuple[tuple[torch.Tensor]]): """Categorize present_key_values from self and cross attention to layer by layer. Reorder past state from grouped by self/cross attention to grouped by layer. @@ -101,7 +100,7 @@ def back_group_by_layer(past_key_values: Tuple[Tuple[torch.Tensor]]): return past_tuples @staticmethod - def group_by_self_and_cross(present_key_values: Tuple[torch.Tensor], concat: bool = False): + def group_by_self_and_cross(present_key_values: tuple[torch.Tensor], concat: bool = False): """Categorize present_key_values into self and cross attention. Split present state from grouped by layer to grouped by self/cross attention. @@ -118,8 +117,8 @@ def group_by_self_and_cross(present_key_values: Tuple[torch.Tensor], concat: boo present_self (Tuple[torch.Tensor]): present key and values from self attention present_cross (Tuple[torch.Tensor]): present key and values from cross attention """ - present_self: List[torch.Tensor] = [] - present_cross: List[torch.Tensor] = [] + present_self: list[torch.Tensor] = [] + present_cross: list[torch.Tensor] = [] for _, present_layer_i in enumerate(present_key_values): assert len(present_layer_i) == 4, f"Expected to have four items. Got {len(present_layer_i)}" present_key_self, present_value_self, present_key_cross, present_value_cross = present_layer_i @@ -131,7 +130,7 @@ def group_by_self_and_cross(present_key_values: Tuple[torch.Tensor], concat: boo return present_self, present_cross @staticmethod - def get_input_names(past_key_values: Tuple[Tuple[torch.Tensor]], encoder=True): + def get_input_names(past_key_values: tuple[tuple[torch.Tensor]], encoder=True): """Process input names of model wrapper. Args: diff --git a/onnxruntime/python/tools/transformers/models/t5/t5_decoder.py b/onnxruntime/python/tools/transformers/models/t5/t5_decoder.py index 19e6bba22dc1a..a93c1705b2cd9 100644 --- a/onnxruntime/python/tools/transformers/models/t5/t5_decoder.py +++ b/onnxruntime/python/tools/transformers/models/t5/t5_decoder.py @@ -8,7 +8,6 @@ import os import tempfile from pathlib import Path -from typing import List, Optional, Union import numpy import onnx @@ -34,8 +33,8 @@ def __init__( self, decoder: torch.nn.Module, lm_head: torch.nn.Module, - config: Union[T5Config, MT5Config], - decoder_start_token_id: Optional[int] = None, + config: T5Config | MT5Config, + decoder_start_token_id: int | None = None, ): super().__init__() self.decoder = decoder @@ -133,11 +132,11 @@ def __init__( ): self.decoder_input_ids: torch.LongTensor = decoder_input_ids self.encoder_attention_mask: torch.LongTensor = encoder_attention_mask - self.past_key_values: Union[List[torch.FloatTensor], List[torch.HalfTensor], None] = past_key_values + self.past_key_values: list[torch.FloatTensor] | list[torch.HalfTensor] | None = past_key_values @staticmethod def create_dummy( - config: Union[T5Config, MT5Config], + config: T5Config | MT5Config, batch_size: int, encode_sequence_length: int, past_decode_sequence_length: int, @@ -211,7 +210,7 @@ def create_dummy( return T5DecoderInputs(decoder_input_ids, encoder_inputs.attention_mask, past) - def to_list(self) -> List: + def to_list(self) -> list: input_list = [ self.decoder_input_ids, self.encoder_attention_mask, @@ -232,7 +231,7 @@ def to_fp32(self): class T5DecoderHelper: @staticmethod def export_onnx( - decoder: Union[T5Decoder, T5DecoderInit], + decoder: T5Decoder | T5DecoderInit, device: torch.device, onnx_model_path: str, verbose: bool = True, @@ -370,7 +369,7 @@ def onnxruntime_inference(ort_session, inputs: T5DecoderInputs): @staticmethod def verify_onnx( - model: Union[T5Decoder, T5DecoderInit], + model: T5Decoder | T5DecoderInit, ort_session: InferenceSession, device: torch.device, use_int32_inputs: bool, diff --git a/onnxruntime/python/tools/transformers/models/t5/t5_encoder.py b/onnxruntime/python/tools/transformers/models/t5/t5_encoder.py index fb61e970c1e0c..c6b0f7ee3adc2 100644 --- a/onnxruntime/python/tools/transformers/models/t5/t5_encoder.py +++ b/onnxruntime/python/tools/transformers/models/t5/t5_encoder.py @@ -9,7 +9,6 @@ import random import tempfile from pathlib import Path -from typing import List, Union import numpy import onnx @@ -26,7 +25,7 @@ class T5Encoder(torch.nn.Module): """T5 encoder outputs only the last hidden state""" - def __init__(self, encoder, config: Union[T5Config, MT5Config]): + def __init__(self, encoder, config: T5Config | MT5Config): super().__init__() self.encoder = encoder self.config = config @@ -72,7 +71,7 @@ def create_dummy( attention_mask[i, :padding_position] = 0 return T5EncoderInputs(input_ids, attention_mask) - def to_list(self) -> List: + def to_list(self) -> list: input_list = [v for v in [self.input_ids, self.attention_mask] if v is not None] return input_list diff --git a/onnxruntime/python/tools/transformers/models/t5/t5_encoder_decoder_init.py b/onnxruntime/python/tools/transformers/models/t5/t5_encoder_decoder_init.py index fd6ea45ef8b7c..c76d7aabdf11a 100644 --- a/onnxruntime/python/tools/transformers/models/t5/t5_encoder_decoder_init.py +++ b/onnxruntime/python/tools/transformers/models/t5/t5_encoder_decoder_init.py @@ -8,7 +8,6 @@ import os import tempfile from pathlib import Path -from typing import List, Optional, Union import numpy import onnx @@ -33,8 +32,8 @@ def __init__( encoder: torch.nn.Module, decoder: torch.nn.Module, lm_head: torch.nn.Module, - config: Union[T5Config, MT5Config], - decoder_start_token_id: Optional[int] = None, + config: T5Config | MT5Config, + decoder_start_token_id: int | None = None, ): super().__init__() self.config = config @@ -62,7 +61,7 @@ def __init__(self, encoder_input_ids, encoder_attention_mask, decoder_input_ids= @staticmethod def create_dummy( - config: Union[T5Config, MT5Config], + config: T5Config | MT5Config, batch_size: int, encode_sequence_length: int, use_decoder_input_ids: int, @@ -83,7 +82,7 @@ def create_dummy( return T5EncoderDecoderInitInputs(encoder_inputs.input_ids, encoder_inputs.attention_mask, decoder_input_ids) - def to_list(self) -> List: + def to_list(self) -> list: input_list = [self.encoder_input_ids, self.encoder_attention_mask] if self.decoder_input_ids is not None: input_list.append(self.decoder_input_ids) diff --git a/onnxruntime/python/tools/transformers/models/t5/t5_helper.py b/onnxruntime/python/tools/transformers/models/t5/t5_helper.py index f7dc9db0e82c8..d3f25e979887d 100755 --- a/onnxruntime/python/tools/transformers/models/t5/t5_helper.py +++ b/onnxruntime/python/tools/transformers/models/t5/t5_helper.py @@ -7,7 +7,6 @@ import logging import os from pathlib import Path -from typing import Dict, List, Union import torch from float16 import float_to_float16_max_diff @@ -64,7 +63,7 @@ def load_model( merge_encoder_and_decoder_init: bool = True, model_type: str = "t5", state_dict_path: str = "", - ) -> Dict[str, torch.nn.Module]: + ) -> dict[str, torch.nn.Module]: """Load model given a pretrained name or path, then build models for ONNX conversion. Args: @@ -111,7 +110,7 @@ def load_model( @staticmethod def export_onnx( - model: Union[T5Encoder, T5Decoder, T5DecoderInit, T5EncoderDecoderInit], + model: T5Encoder | T5Decoder | T5DecoderInit | T5EncoderDecoderInit, device: torch.device, onnx_model_path: str, verbose: bool = True, @@ -151,7 +150,7 @@ def export_onnx( @staticmethod def auto_mixed_precision( onnx_model: OnnxModel, - op_block_list: List[str] = [ # noqa: B006 + op_block_list: list[str] = [ # noqa: B006 "SimplifiedLayerNormalization", "SkipSimplifiedLayerNormalization", "Relu", @@ -257,7 +256,7 @@ def optimize_onnx( @staticmethod def verify_onnx( - model: Union[T5Encoder, T5Decoder, T5DecoderInit, T5EncoderDecoderInit], + model: T5Encoder | T5Decoder | T5DecoderInit | T5EncoderDecoderInit, ort_session: InferenceSession, device: torch.device, use_int32_inputs: bool, diff --git a/onnxruntime/python/tools/transformers/models/whisper/whisper_chain.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_chain.py index 87ac45101f0c0..feb688948d8f5 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/whisper_chain.py +++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_chain.py @@ -24,7 +24,7 @@ def verify_inputs(beam_inputs, graph_inputs): # Verify that ONNX graph's inputs match beam search op's inputs beam_required_inputs = list(filter(lambda beam_input: beam_input, beam_inputs)) assert len(graph_inputs) == len(beam_required_inputs) - for graph_input, beam_input in zip(graph_inputs, beam_required_inputs): + for graph_input, beam_input in zip(graph_inputs, beam_required_inputs, strict=False): # Check if graph_input is in beam_input to handle beam_input names with the "_fp16" suffix assert graph_input.name in beam_input diff --git a/onnxruntime/python/tools/transformers/models/whisper/whisper_decoder.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_decoder.py index 5da235d72ca0b..400cafc4c93c3 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/whisper_decoder.py +++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_decoder.py @@ -8,7 +8,6 @@ import os import tempfile from pathlib import Path -from typing import List, Optional, Union import numpy import onnx @@ -34,7 +33,7 @@ def __init__( self, decoder: torch.nn.Module, config: WhisperConfig, - decoder_start_token_id: Optional[int] = None, + decoder_start_token_id: int | None = None, ): super().__init__() self.decoder = decoder @@ -115,7 +114,7 @@ def __init__( past_key_values=None, ): self.decoder_input_ids: torch.LongTensor = decoder_input_ids - self.past_key_values: Union[List[torch.FloatTensor], List[torch.HalfTensor], None] = past_key_values + self.past_key_values: list[torch.FloatTensor] | list[torch.HalfTensor] | None = past_key_values @staticmethod def create_dummy( @@ -186,7 +185,7 @@ def create_dummy( return WhisperDecoderInputs(decoder_input_ids, past) - def to_list(self) -> List: + def to_list(self) -> list: input_list = [self.decoder_input_ids] if self.past_key_values: input_list.extend(self.past_key_values) @@ -333,7 +332,7 @@ def onnxruntime_inference(ort_session, inputs: WhisperDecoderInputs): @staticmethod def verify_onnx( - model: Union[WhisperDecoder, WhisperDecoderInit], + model: WhisperDecoder | WhisperDecoderInit, ort_session: InferenceSession, device: torch.device, use_int32_inputs: bool, diff --git a/onnxruntime/python/tools/transformers/models/whisper/whisper_encoder.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_encoder.py index 93281848a5c9c..0b9db81486caa 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/whisper_encoder.py +++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_encoder.py @@ -8,7 +8,6 @@ import os import tempfile from pathlib import Path -from typing import List import numpy import onnx @@ -67,7 +66,7 @@ def create_dummy( ) return WhisperEncoderInputs(input_features) - def to_list(self) -> List: + def to_list(self) -> list: if self.input_ids is None: return [] return [self.input_ids] diff --git a/onnxruntime/python/tools/transformers/models/whisper/whisper_encoder_decoder_init.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_encoder_decoder_init.py index fab2a2aa4c8a8..c7c7a7675c1a7 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/whisper_encoder_decoder_init.py +++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_encoder_decoder_init.py @@ -8,7 +8,6 @@ import os import tempfile from pathlib import Path -from typing import List, Optional import numpy import onnx @@ -34,7 +33,7 @@ def __init__( encoder: torch.nn.Module, decoder: torch.nn.Module, config: WhisperConfig, - decoder_start_token_id: Optional[int] = None, + decoder_start_token_id: int | None = None, model_impl: str = "hf", model: torch.nn.Module = None, ): @@ -94,7 +93,7 @@ def create_dummy( return WhisperEncoderDecoderInitInputs(encoder_inputs.input_ids, decoder_input_ids) - def to_list(self) -> List: + def to_list(self) -> list: input_list = [self.encoder_input_ids] if self.decoder_input_ids is not None: input_list.append(self.decoder_input_ids) diff --git a/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py index 9fb51dd9b43c0..38003c2693296 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py +++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py @@ -7,7 +7,6 @@ import logging import os from pathlib import Path -from typing import Dict, Tuple, Union import numpy as np import torch @@ -117,7 +116,7 @@ def load_model( device: torch.device, merge_encoder_and_decoder_init: bool = True, state_dict_path: str = "", - ) -> Dict[str, torch.nn.Module]: + ) -> dict[str, torch.nn.Module]: """Load model given a pretrained name or path, then build models for ONNX conversion. Args: @@ -170,7 +169,7 @@ def load_model( @staticmethod def export_onnx( - model: Union[WhisperEncoder, WhisperDecoder, WhisperDecoderInit, WhisperEncoderDecoderInit], + model: WhisperEncoder | WhisperDecoder | WhisperDecoderInit | WhisperEncoderDecoderInit, device: torch.device, onnx_model_path: str, verbose: bool = True, @@ -209,7 +208,7 @@ def export_onnx( @staticmethod def auto_mixed_precision( onnx_model: OnnxModel, - op_block_list: Tuple[str] = ( + op_block_list: tuple[str] = ( "SimplifiedLayerNormalization", "SkipSimplifiedLayerNormalization", "Relu", @@ -460,7 +459,7 @@ def verify_onnx( } use_extra_decoding_ids = "extra_decoding_ids" in ort_names - for name, dtype in zip(ort_names, ort_dtypes): + for name, dtype in zip(ort_names, ort_dtypes, strict=False): if name == "input_features": inputs[name] = inputs[name].detach().cpu().numpy() elif name == "vocab_mask": diff --git a/onnxruntime/python/tools/transformers/onnx_model.py b/onnxruntime/python/tools/transformers/onnx_model.py index 33506d6d00cac..ef80d36be3b18 100644 --- a/onnxruntime/python/tools/transformers/onnx_model.py +++ b/onnxruntime/python/tools/transformers/onnx_model.py @@ -9,7 +9,6 @@ import sys from collections import deque from pathlib import Path -from typing import Dict, List, Optional, Tuple from float16 import convert_float_to_float16 from onnx import ( @@ -35,16 +34,16 @@ def __init__(self, model): def initialize(self, model): self.model: ModelProto = model - self._node_name_suffix: Dict[str, int] = {} # key is node name prefix, value is the last suffix generated + self._node_name_suffix: dict[str, int] = {} # key is node name prefix, value is the last suffix generated self.shape_infer_helper: SymbolicShapeInferenceHelper = None self.enable_shape_infer: bool = True - self.all_graphs: Optional[List[GraphProto]] = None + self.all_graphs: list[GraphProto] | None = None # Cache of shape and data type from onnx graph to speed up optimization. # Be careful that fusion shall not reuse node output name for different shape/type (in adding/removing nodes) # Note that these do not cache the symbolic shape inference result. - self._dtype_dict: Optional[Dict[str, int]] = None - self._shape_dict: Optional[Dict[str, List]] = None + self._dtype_dict: dict[str, int] | None = None + self._shape_dict: dict[str, list] | None = None def disable_shape_inference(self): self.enable_shape_infer = False @@ -348,7 +347,7 @@ def match_parent( def match_parent_paths(self, node, paths, output_name_to_node): for i, path in enumerate(paths): - assert isinstance(path, (List, Tuple)) + assert isinstance(path, (list, tuple)) return_indice = [] matched = self.match_parent_path(node, path[0], path[1], output_name_to_node, return_indice) if matched: @@ -358,7 +357,7 @@ def match_parent_paths(self, node, paths, output_name_to_node): def match_parent_paths_all(self, node, paths, output_name_to_node): match_i, matches, return_indices = [], [], [] for i, path in enumerate(paths): - assert isinstance(path, (List, Tuple)) + assert isinstance(path, (list, tuple)) return_indice = [] matched = self.match_parent_path(node, path[0], path[1], output_name_to_node, return_indice) if matched: @@ -442,7 +441,7 @@ def match_child_path( self, node, child_op_types, - edges: Optional[List[Tuple[int, int]]] = None, + edges: list[tuple[int, int]] | None = None, input_name_to_nodes=None, exclude=[], # noqa: B006 ): @@ -600,7 +599,7 @@ def tensor_shape_to_list(self, tensor_type): shape_list.append("?") # shall not happen return shape_list - def get_dtype(self, name: str, symbolic_shape_helper: Optional[SymbolicShapeInferenceHelper] = None): + def get_dtype(self, name: str, symbolic_shape_helper: SymbolicShapeInferenceHelper | None = None): """Try get data type given a name (could be initializer, input or output of graph or node).""" if self._dtype_dict is None: @@ -625,7 +624,7 @@ def get_dtype(self, name: str, symbolic_shape_helper: Optional[SymbolicShapeInfe return None - def get_shape(self, name: str, symbolic_shape_helper: Optional[SymbolicShapeInferenceHelper] = None): + def get_shape(self, name: str, symbolic_shape_helper: SymbolicShapeInferenceHelper | None = None): """Try get shape given a name (could be initializer, input or output of graph or node).""" if self._shape_dict is None: @@ -1320,8 +1319,8 @@ def to_data_hash(tensor: TensorProto, base_dir: str = "") -> int: def has_same_value( tensor1: TensorProto, tensor2: TensorProto, - signature_cache1: Optional[dict] = None, - signature_cache2: Optional[dict] = None, + signature_cache1: dict | None = None, + signature_cache2: dict | None = None, ) -> bool: """Returns True when two tensors have same value. Note that name can be different. @@ -1354,7 +1353,7 @@ def has_same_value( return False - def remove_duplicated_initializer(self, cache: Optional[dict] = None): + def remove_duplicated_initializer(self, cache: dict | None = None): """Remove initializers with duplicated values, and only keep the first one. It could help reduce size of models (like ALBert) with shared weights. If require_raw_data passed, method will only compare raw_data initializers to speed runtime diff --git a/onnxruntime/python/tools/transformers/onnx_model_bart.py b/onnxruntime/python/tools/transformers/onnx_model_bart.py index 61a786d7af60b..496146dbf8cb5 100644 --- a/onnxruntime/python/tools/transformers/onnx_model_bart.py +++ b/onnxruntime/python/tools/transformers/onnx_model_bart.py @@ -3,7 +3,6 @@ # Licensed under the MIT License. # -------------------------------------------------------------------------- import logging -from typing import Optional from fusion_attention import AttentionMask from fusion_bart_attention import FusionBartAttention @@ -127,7 +126,7 @@ def __init__(self, model, num_heads, hidden_size, model_impl="hf"): self.attention_fusion = FusionBartAttention(self, self.hidden_size, self.num_heads, self.attention_mask) self.bart_reshape_fusion_preprocess = FusionBartReshape(self) - def optimize(self, options: Optional[FusionOptions] = None, add_dynamic_axes: bool = False): + def optimize(self, options: FusionOptions | None = None, add_dynamic_axes: bool = False): self.attention_fusion.use_multi_head_attention = False if options is None else options.use_multi_head_attention self.attention_fusion.disable_multi_head_attention_bias = ( False if options is None else options.disable_multi_head_attention_bias diff --git a/onnxruntime/python/tools/transformers/onnx_model_bert.py b/onnxruntime/python/tools/transformers/onnx_model_bert.py index 26464fc32817d..c4e8b64fd8486 100644 --- a/onnxruntime/python/tools/transformers/onnx_model_bert.py +++ b/onnxruntime/python/tools/transformers/onnx_model_bert.py @@ -4,7 +4,6 @@ # -------------------------------------------------------------------------- from logging import getLogger -from typing import List, Optional from convert_to_packing_mode import PackingMode from fusion_attention import AttentionMask, FusionAttention @@ -147,7 +146,7 @@ def fuse_qordered_mamtul(self): fusion = FusionQOrderedMatMul(self) fusion.apply() - def get_graph_inputs_from_node_type(self, op_type: str, input_indices: List[int], casted: bool): + def get_graph_inputs_from_node_type(self, op_type: str, input_indices: list[int], casted: bool): """ Get graph inputs that feed into node type (like EmbedLayerNormalization or Attention). Returns a list of the graph input names based on the filter whether it is casted or not. @@ -323,7 +322,7 @@ def postprocess(self): self.clean_graph() self.prune_graph() - def optimize(self, options: Optional[FusionOptions] = None, add_dynamic_axes: bool = False): + def optimize(self, options: FusionOptions | None = None, add_dynamic_axes: bool = False): if (options is not None) and not options.enable_shape_inference: self.disable_shape_inference() diff --git a/onnxruntime/python/tools/transformers/onnx_model_conformer.py b/onnxruntime/python/tools/transformers/onnx_model_conformer.py index 1506d85f53fd4..65723aabc2e18 100644 --- a/onnxruntime/python/tools/transformers/onnx_model_conformer.py +++ b/onnxruntime/python/tools/transformers/onnx_model_conformer.py @@ -3,7 +3,6 @@ # Licensed under the MIT License. # -------------------------------------------------------------------------- import logging -from typing import Optional from fusion_attention import AttentionMask from fusion_conformer_attention import FusionConformerAttention @@ -19,7 +18,7 @@ def __init__(self, model, num_heads, hidden_size): self.attention_mask = AttentionMask(self) self.attention_fusion = FusionConformerAttention(self, self.hidden_size, self.num_heads, self.attention_mask) - def optimize(self, options: Optional[FusionOptions] = None, add_dynamic_axes: bool = False): + def optimize(self, options: FusionOptions | None = None, add_dynamic_axes: bool = False): self.attention_fusion.use_multi_head_attention = False if options is None else options.use_multi_head_attention self.attention_fusion.disable_multi_head_attention_bias = ( False if options is None else options.disable_multi_head_attention_bias diff --git a/onnxruntime/python/tools/transformers/onnx_model_mmdit.py b/onnxruntime/python/tools/transformers/onnx_model_mmdit.py index 4c9b19c0c97ca..35a574129e78c 100644 --- a/onnxruntime/python/tools/transformers/onnx_model_mmdit.py +++ b/onnxruntime/python/tools/transformers/onnx_model_mmdit.py @@ -4,7 +4,6 @@ # -------------------------------------------------------------------------- import logging -from typing import Optional from fusion_layernorm import FusionLayerNormalization from fusion_mha_mmdit import FusionMultiHeadAttentionMMDit @@ -47,7 +46,7 @@ def fuse_multi_head_attention(self): fusion = FusionMultiHeadAttentionMMDit(self) fusion.apply() - def optimize(self, options: Optional[FusionOptions] = None, add_dynamic_axes: bool = False): + def optimize(self, options: FusionOptions | None = None, add_dynamic_axes: bool = False): assert not add_dynamic_axes if is_installed("tqdm"): @@ -62,7 +61,7 @@ def optimize(self, options: Optional[FusionOptions] = None, add_dynamic_axes: bo logger.info("tqdm is not installed. Run optimization without progress bar") self._optimize(options, None) - def _optimize(self, options: Optional[FusionOptions] = None, progress_bar=None): + def _optimize(self, options: FusionOptions | None = None, progress_bar=None): if (options is not None) and not options.enable_shape_inference: self.disable_shape_inference() diff --git a/onnxruntime/python/tools/transformers/onnx_model_phi.py b/onnxruntime/python/tools/transformers/onnx_model_phi.py index 5df765033578b..d2f10d0bc18af 100644 --- a/onnxruntime/python/tools/transformers/onnx_model_phi.py +++ b/onnxruntime/python/tools/transformers/onnx_model_phi.py @@ -4,7 +4,6 @@ # -------------------------------------------------------------------------- from logging import getLogger -from typing import List, Optional import numpy as np from dynamo_onnx_helper import DynamoOnnxHelper @@ -70,7 +69,7 @@ class Fission(Fusion): def __init__( self, model: OnnxModel, - nodes_to_find: List[str], + nodes_to_find: list[str], ): super().__init__(model, "DONOTUSE", nodes_to_find) @@ -129,7 +128,7 @@ def replace_fp32_value_info(self, name, shape): self.model.graph().value_info.extend([new_value_info]) def set_unique_name_and_add_nodes( - self, subgraph_nodes: List[NodeProto], layer_id: int, layer_known_edges_names: List[str] + self, subgraph_nodes: list[NodeProto], layer_id: int, layer_known_edges_names: list[str] ): for new_node in subgraph_nodes: for i, name in enumerate(new_node.input): @@ -148,7 +147,7 @@ def set_unique_name_and_add_nodes( self.nodes_to_add.append(new_node) self.node_name_to_graph_name[new_node.name] = self.this_graph_name - def layernorm(self, inputs: List[str], outputs: List[str], prefix: str = ""): + def layernorm(self, inputs: list[str], outputs: list[str], prefix: str = ""): assert len(inputs) == 3 assert len(outputs) == 1 node = helper.make_node( @@ -160,7 +159,7 @@ def layernorm(self, inputs: List[str], outputs: List[str], prefix: str = ""): ) return [node] - def gemm(self, inputs: List[str], outputs: List[str], prefix: str = ""): + def gemm(self, inputs: list[str], outputs: list[str], prefix: str = ""): assert len(inputs) == 3 assert len(outputs) == 1 matmul = helper.make_node( @@ -177,7 +176,7 @@ def gemm(self, inputs: List[str], outputs: List[str], prefix: str = ""): ) return [matmul, add] - def rotary(self, inputs: List[str], outputs: List[str], prefix: str = "", rot_dim=32, num_heads=32): + def rotary(self, inputs: list[str], outputs: list[str], prefix: str = "", rot_dim=32, num_heads=32): assert len(inputs) == 4 assert len(outputs) == 1 node = helper.make_node( @@ -191,7 +190,7 @@ def rotary(self, inputs: List[str], outputs: List[str], prefix: str = "", rot_di ) return [node] - def fastgelu(self, inputs: List[str], outputs: List[str], prefix: str = ""): + def fastgelu(self, inputs: list[str], outputs: list[str], prefix: str = ""): assert len(inputs) == 1 assert len(outputs) == 1 node = helper.make_node( @@ -203,7 +202,7 @@ def fastgelu(self, inputs: List[str], outputs: List[str], prefix: str = ""): ) return [node] - def add(self, inputs: List[str], outputs: List[str], prefix: str = ""): + def add(self, inputs: list[str], outputs: list[str], prefix: str = ""): assert len(inputs) == 2 assert len(outputs) == 1 node = helper.make_node( @@ -214,7 +213,7 @@ def add(self, inputs: List[str], outputs: List[str], prefix: str = ""): ) return [node] - def mha(self, inputs: List[str], outputs: List[str], prefix: str = "", num_heads=32): + def mha(self, inputs: list[str], outputs: list[str], prefix: str = "", num_heads=32): assert len(inputs) == 8 assert len(outputs) == 3 node = helper.make_node( @@ -228,7 +227,7 @@ def mha(self, inputs: List[str], outputs: List[str], prefix: str = "", num_heads ) return [node] - def gqa(self, inputs: List[str], outputs: List[str], prefix: str = "", num_heads=32): + def gqa(self, inputs: list[str], outputs: list[str], prefix: str = "", num_heads=32): assert len(inputs) == 7 assert len(outputs) == 3 node = helper.make_node( @@ -242,7 +241,7 @@ def gqa(self, inputs: List[str], outputs: List[str], prefix: str = "", num_heads ) return [node] - def attention(self, inputs: List[str], outputs: List[str], prefix: str = "", num_heads=32): + def attention(self, inputs: list[str], outputs: list[str], prefix: str = "", num_heads=32): assert len(inputs) == 5 assert len(outputs) == 2 node = helper.make_node( @@ -260,8 +259,8 @@ def attention(self, inputs: List[str], outputs: List[str], prefix: str = "", num def paged_attn( self, - inputs: List[str], - outputs: List[str], + inputs: list[str], + outputs: list[str], prefix: str = "", num_heads=32, head_size=80, @@ -853,7 +852,7 @@ def __init__(self, model: ModelProto, num_heads: int, hidden_size: int): self.fission_transformer_layernorm = FissionTransformerLayerNormPhi(self) self.fission_transformer_embedding = FissionTransformerEmbeddingPhi(self) - def optimize(self, options: Optional[FusionOptions] = None, add_dynamic_axes: bool = False): + def optimize(self, options: FusionOptions | None = None, add_dynamic_axes: bool = False): assert options is not None attn_op_type = options.attention_op_type diff --git a/onnxruntime/python/tools/transformers/onnx_model_sam2.py b/onnxruntime/python/tools/transformers/onnx_model_sam2.py index ac608fb509a81..9d57081c4ce12 100644 --- a/onnxruntime/python/tools/transformers/onnx_model_sam2.py +++ b/onnxruntime/python/tools/transformers/onnx_model_sam2.py @@ -4,7 +4,6 @@ # -------------------------------------------------------------------------- import logging -from typing import Optional from fusion_attention_sam2 import FusionMultiHeadAttentionSam2 from fusion_layernorm import FusionLayerNormalizationNCHW @@ -39,11 +38,11 @@ def fuse_layer_norm(self): fusion = FusionLayerNormalizationNCHW(self) fusion.apply() - def fuse_multi_head_attention(self, options: Optional[FusionOptions] = None): + def fuse_multi_head_attention(self, options: FusionOptions | None = None): mha_fusion = FusionMultiHeadAttentionSam2(self, self.hidden_size, self.num_heads) mha_fusion.apply() - def optimize(self, options: Optional[FusionOptions] = None, add_dynamic_axes: bool = False): + def optimize(self, options: FusionOptions | None = None, add_dynamic_axes: bool = False): if is_installed("tqdm"): import tqdm from tqdm.contrib.logging import logging_redirect_tqdm @@ -56,7 +55,7 @@ def optimize(self, options: Optional[FusionOptions] = None, add_dynamic_axes: bo logger.info("tqdm is not installed. Run optimization without progress bar") self._optimize(options, None) - def _optimize(self, options: Optional[FusionOptions] = None, progress_bar=None): + def _optimize(self, options: FusionOptions | None = None, progress_bar=None): if (options is not None) and not options.enable_shape_inference: self.disable_shape_inference() diff --git a/onnxruntime/python/tools/transformers/onnx_model_t5.py b/onnxruntime/python/tools/transformers/onnx_model_t5.py index 70742bb5f52e3..33dcc7795a465 100644 --- a/onnxruntime/python/tools/transformers/onnx_model_t5.py +++ b/onnxruntime/python/tools/transformers/onnx_model_t5.py @@ -3,7 +3,6 @@ # Licensed under the MIT License. # -------------------------------------------------------------------------- import logging -from typing import Optional, Union import numpy as np from fusion_attention import AttentionMask, FusionAttention @@ -50,8 +49,8 @@ def create_attention_node( input: str, output: str, add_qk_str: str, - scale: Optional[float] = None, - ) -> Union[NodeProto, None]: + scale: float | None = None, + ) -> NodeProto | None: """Create an Attention node. Args: mask_index (str): mask input @@ -163,7 +162,7 @@ def create_mha_node( present_value: str, num_heads: int, hidden_size: int, - ) -> Union[NodeProto, None]: + ) -> NodeProto | None: assert num_heads > 0 if hidden_size > 0 and (hidden_size % num_heads) != 0: diff --git a/onnxruntime/python/tools/transformers/onnx_model_tnlr.py b/onnxruntime/python/tools/transformers/onnx_model_tnlr.py index f5a47b19d67fc..125aa47a7dbed 100644 --- a/onnxruntime/python/tools/transformers/onnx_model_tnlr.py +++ b/onnxruntime/python/tools/transformers/onnx_model_tnlr.py @@ -3,7 +3,6 @@ # Licensed under the MIT License. # -------------------------------------------------------------------------- import logging -from typing import Union from fusion_attention import AttentionMask, FusionAttention from fusion_utils import NumpyHelper @@ -39,7 +38,7 @@ def create_attention_node( input: str, output: str, add_qk_str: str, - ) -> Union[NodeProto, None]: + ) -> NodeProto | None: assert num_heads > 0 if hidden_size > 0 and (hidden_size % num_heads) != 0: logger.debug(f"input hidden size {hidden_size} is not a multiple of num of heads {num_heads}") diff --git a/onnxruntime/python/tools/transformers/onnx_model_unet.py b/onnxruntime/python/tools/transformers/onnx_model_unet.py index 77e24986f0fde..e96cf32927171 100644 --- a/onnxruntime/python/tools/transformers/onnx_model_unet.py +++ b/onnxruntime/python/tools/transformers/onnx_model_unet.py @@ -4,7 +4,6 @@ # -------------------------------------------------------------------------- import logging -from typing import Optional from fusion_attention_unet import FusionAttentionUnet from fusion_bias_add import FusionBiasAdd @@ -91,7 +90,7 @@ def merge_adjacent_transpose(self): if total: logger.info("Removed %d Transpose nodes", total) - def fuse_multi_head_attention(self, options: Optional[FusionOptions] = None): + def fuse_multi_head_attention(self, options: FusionOptions | None = None): # Self Attention enable_packed_qkv = (options is None) or options.enable_packed_qkv self_attention_fusion = FusionAttentionUnet( @@ -120,7 +119,7 @@ def fuse_bias_add(self): fusion = FusionBiasAdd(self) fusion.apply() - def optimize(self, options: Optional[FusionOptions] = None): + def optimize(self, options: FusionOptions | None = None): if is_installed("tqdm"): import tqdm from tqdm.contrib.logging import logging_redirect_tqdm @@ -133,7 +132,7 @@ def optimize(self, options: Optional[FusionOptions] = None): logger.info("tqdm is not installed. Run optimization without progress bar") self._optimize(options, None) - def _optimize(self, options: Optional[FusionOptions] = None, progress_bar=None): + def _optimize(self, options: FusionOptions | None = None, progress_bar=None): if (options is not None) and not options.enable_shape_inference: self.disable_shape_inference() diff --git a/onnxruntime/python/tools/transformers/onnx_model_vae.py b/onnxruntime/python/tools/transformers/onnx_model_vae.py index de8b59074a871..1e531bbc3eff3 100644 --- a/onnxruntime/python/tools/transformers/onnx_model_vae.py +++ b/onnxruntime/python/tools/transformers/onnx_model_vae.py @@ -4,7 +4,6 @@ # -------------------------------------------------------------------------- from logging import getLogger -from typing import Optional from fusion_attention_vae import FusionAttentionVae from fusion_options import FusionOptions @@ -19,7 +18,7 @@ def __init__(self, model: ModelProto, num_heads: int = 0, hidden_size: int = 0): assert (num_heads == 0 and hidden_size == 0) or (num_heads > 0 and hidden_size % num_heads == 0) super().__init__(model, num_heads=num_heads, hidden_size=hidden_size) - def fuse_multi_head_attention(self, options: Optional[FusionOptions] = None): + def fuse_multi_head_attention(self, options: FusionOptions | None = None): # Self Attention self_attention_fusion = FusionAttentionVae(self, self.hidden_size, self.num_heads) self_attention_fusion.apply() diff --git a/onnxruntime/python/tools/transformers/onnx_utils.py b/onnxruntime/python/tools/transformers/onnx_utils.py index 64fade9369395..7f681d783cb64 100644 --- a/onnxruntime/python/tools/transformers/onnx_utils.py +++ b/onnxruntime/python/tools/transformers/onnx_utils.py @@ -35,7 +35,7 @@ def extract_raw_data_from_model(model: ModelProto): initializer.name = name initializer.ClearField("raw_data") - return zip(*external_data) + return zip(*external_data, strict=False) def has_external_data(model: ModelProto): diff --git a/onnxruntime/python/tools/transformers/optimizer.py b/onnxruntime/python/tools/transformers/optimizer.py index a83c54e345d7d..c4d187e8bf031 100644 --- a/onnxruntime/python/tools/transformers/optimizer.py +++ b/onnxruntime/python/tools/transformers/optimizer.py @@ -22,7 +22,6 @@ import os import tempfile from pathlib import Path -from typing import Dict, List, Optional, Union import coloredlogs from fusion_options import FusionOptions @@ -72,17 +71,17 @@ def optimize_by_onnxruntime( - onnx_model: Optional[Union[str, ModelProto]] = None, + onnx_model: str | ModelProto | None = None, use_gpu: bool = False, - optimized_model_path: Optional[str] = None, - opt_level: Optional[int] = 99, - disabled_optimizers: List[str] = [], # noqa: B006 + optimized_model_path: str | None = None, + opt_level: int | None = 99, + disabled_optimizers: list[str] = [], # noqa: B006 verbose: bool = False, save_as_external_data: bool = False, external_data_filename: str = "", external_data_file_threshold: int = 1024, *, - provider: Optional[str] = None, + provider: str | None = None, **deprecated_kwargs, ) -> str: """ @@ -217,7 +216,7 @@ def optimize_by_fusion( model_type: str = "bert", num_heads: int = 0, hidden_size: int = 0, - optimization_options: Optional[FusionOptions] = None, + optimization_options: FusionOptions | None = None, ) -> OnnxModel: """Optimize Model by graph fusion logic. @@ -274,17 +273,17 @@ def optimize_by_fusion( def optimize_model( - input: Union[str, ModelProto], + input: str | ModelProto, model_type: str = "bert", num_heads: int = 0, hidden_size: int = 0, - optimization_options: Optional[FusionOptions] = None, - opt_level: Optional[int] = None, + optimization_options: FusionOptions | None = None, + opt_level: int | None = None, use_gpu: bool = False, only_onnxruntime: bool = False, verbose: bool = False, *, - provider: Optional[str] = None, + provider: str | None = None, ) -> OnnxModel: """Optimize Model by OnnxRuntime and/or python fusion logic. @@ -414,7 +413,7 @@ def optimize_model( return optimizer -def get_fusion_statistics(optimized_model_path: str) -> Dict[str, int]: +def get_fusion_statistics(optimized_model_path: str) -> dict[str, int]: """ Get counter of fused operators in optimized model. diff --git a/onnxruntime/python/tools/transformers/shape_infer_helper.py b/onnxruntime/python/tools/transformers/shape_infer_helper.py index f1fc0c952e8e4..f4d65d05ad0c8 100644 --- a/onnxruntime/python/tools/transformers/shape_infer_helper.py +++ b/onnxruntime/python/tools/transformers/shape_infer_helper.py @@ -6,7 +6,6 @@ import logging import os import sys -from typing import Dict # In ORT Package the symbolic_shape_infer.py is in ../tools file_path = os.path.dirname(__file__) @@ -26,9 +25,9 @@ def __init__(self, model, verbose=0, int_max=2**31 - 1, auto_merge=True, guess_o self.model_ = model self.all_shapes_inferred_: bool = False self.is_inferred_: bool = False - self.dynamic_axis_mapping_: Dict[str, int] = {} + self.dynamic_axis_mapping_: dict[str, int] = {} - def infer(self, dynamic_axis_mapping: Dict[str, int], max_runs: int = 200): + def infer(self, dynamic_axis_mapping: dict[str, int], max_runs: int = 200): """Run shape inference, and try replace dynamic axis from string to integer when mapping is provided. Args: diff --git a/onnxruntime/python/tools/transformers/shape_optimizer.py b/onnxruntime/python/tools/transformers/shape_optimizer.py index 17fd54f19baf2..9f590dfb86911 100644 --- a/onnxruntime/python/tools/transformers/shape_optimizer.py +++ b/onnxruntime/python/tools/transformers/shape_optimizer.py @@ -16,7 +16,6 @@ from collections import deque # noqa: F401 from datetime import datetime from pathlib import Path # noqa: F401 -from typing import List, Optional import numpy as np import onnx @@ -271,7 +270,7 @@ def validate_input(self, input: str): valid_names = [input.name for input in self.model.graph.input] raise Exception(f"Input {input} does not exist in the graph inputs: {valid_names}") - def validate_outputs(self, output_names: List[str]): + def validate_outputs(self, output_names: list[str]): valid_names = [output.name for output in self.model.graph.output] for name in output_names: if name not in valid_names: @@ -285,7 +284,7 @@ def optimize( input_mask: str, enable_shape_opt: bool, enable_reshape_opt: bool, - output_names: Optional[List[str]] = None, + output_names: list[str] | None = None, batch_size=1, sequence_length=128, verbose=False, diff --git a/onnxruntime/test/contrib_ops/multihead_attention_op_test_data_gen.py b/onnxruntime/test/contrib_ops/multihead_attention_op_test_data_gen.py index bdb0ffc6c50db..52ce2ef5fdef1 100644 --- a/onnxruntime/test/contrib_ops/multihead_attention_op_test_data_gen.py +++ b/onnxruntime/test/contrib_ops/multihead_attention_op_test_data_gen.py @@ -7,7 +7,6 @@ # CUBLAS_WORKSPACE_CONFIG=:4096:8 python multihead_attention_op_test_data_gen.py import math -from typing import Optional, Tuple import numpy as np import torch @@ -56,12 +55,12 @@ def get_extended_attention_mask(self, attention_mask: Tensor, dtype: torch.dtype def forward( self, hidden_states: torch.Tensor, - attention_mask: Optional[torch.FloatTensor] = None, - encoder_hidden_states: Optional[torch.FloatTensor] = None, - encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, - output_attentions: Optional[bool] = False, - ) -> Tuple[torch.Tensor]: + attention_mask: torch.FloatTensor | None = None, + encoder_hidden_states: torch.FloatTensor | None = None, + encoder_attention_mask: torch.FloatTensor | None = None, + past_key_value: tuple[tuple[torch.FloatTensor]] | None = None, + output_attentions: bool | None = False, + ) -> tuple[torch.Tensor]: mixed_query_layer = self.query(hidden_states) if self.verbose: print("q", mixed_query_layer) diff --git a/onnxruntime/test/providers/cpu/rnn/LSTM.py b/onnxruntime/test/providers/cpu/rnn/LSTM.py index 49e28a93385a4..472fa5f844ac0 100644 --- a/onnxruntime/test/providers/cpu/rnn/LSTM.py +++ b/onnxruntime/test/providers/cpu/rnn/LSTM.py @@ -2,13 +2,7 @@ # Licensed under the MIT License. -from typing import Any, Tuple # noqa: F401 - -import numpy as np # type: ignore - -# import onnx -# from ..base import Base -# from . import expect +import numpy as np DebugOutput = True np.set_printoptions(suppress=True) # , precision=16, floatmode='maxprec') diff --git a/onnxruntime/test/providers/cpu/tensor/affine_grid_test_gen.py b/onnxruntime/test/providers/cpu/tensor/affine_grid_test_gen.py index 796a58f1a929c..5276b70789db1 100644 --- a/onnxruntime/test/providers/cpu/tensor/affine_grid_test_gen.py +++ b/onnxruntime/test/providers/cpu/tensor/affine_grid_test_gen.py @@ -24,7 +24,7 @@ test_count = 0 for align_corners in align_corners_options: - for angle, translation, scale in zip(angles, translations, scales): + for angle, translation, scale in zip(angles, translations, scales, strict=False): for size in sizes: theta = np.array([], dtype=np.float32) for _ in range(size[0]): @@ -71,7 +71,7 @@ test_count = 0 for align_corners in align_corners_options: - for angle, translation, scale in zip(angles, translations, scales): + for angle, translation, scale in zip(angles, translations, scales, strict=False): for size in sizes: theta = np.array([], dtype=np.float32) for _ in range(size[0]): diff --git a/onnxruntime/test/python/contrib_ops/onnx_contrib_ops_helper.py b/onnxruntime/test/python/contrib_ops/onnx_contrib_ops_helper.py index dd5d5cc90e0bf..1459dfc61c84c 100644 --- a/onnxruntime/test/python/contrib_ops/onnx_contrib_ops_helper.py +++ b/onnxruntime/test/python/contrib_ops/onnx_contrib_ops_helper.py @@ -65,11 +65,11 @@ def expect( del kwargs["output_types"] inputs_vi = [ _extract_value_info(arr, arr_name, input_type) - for arr, arr_name, input_type in zip(inputs, present_inputs, input_types) + for arr, arr_name, input_type in zip(inputs, present_inputs, input_types, strict=False) ] outputs_vi = [ _extract_value_info(arr, arr_name, output_type) - for arr, arr_name, output_type in zip(outputs, present_outputs, output_types) + for arr, arr_name, output_type in zip(outputs, present_outputs, output_types, strict=False) ] graph = onnx.helper.make_graph(nodes=[node], name=name, inputs=inputs_vi, outputs=outputs_vi) diff --git a/onnxruntime/test/python/onnx_backend_test_series.py b/onnxruntime/test/python/onnx_backend_test_series.py index 8fc76da3495a8..23f6d3e23e9bf 100644 --- a/onnxruntime/test/python/onnx_backend_test_series.py +++ b/onnxruntime/test/python/onnx_backend_test_series.py @@ -9,7 +9,6 @@ import re import sys import unittest -from typing import Dict import numpy as np import onnx @@ -28,8 +27,8 @@ class OrtBackendTest(onnx.backend.test.runner.Runner): # pylint: disable=too-few-public-methods def __init__( self, - rtol_overrides: Dict[str, float], - atol_overrides: Dict[str, float], + rtol_overrides: dict[str, float], + atol_overrides: dict[str, float], ): self._rtol_overrides = rtol_overrides self._atol_overrides = atol_overrides diff --git a/onnxruntime/test/python/onnxruntime_test_distributed.py b/onnxruntime/test/python/onnxruntime_test_distributed.py index de70478761f19..7f4f4b5bb2270 100644 --- a/onnxruntime/test/python/onnxruntime_test_distributed.py +++ b/onnxruntime/test/python/onnxruntime_test_distributed.py @@ -2,7 +2,6 @@ # Licensed under the MIT License. import unittest -from typing import Tuple import numpy as np import onnxscript @@ -23,7 +22,7 @@ def shard_tensor_per_device_mesh(X, rank, axis, device_mesh): if axis is None: return X shards = np.split(X, len(device_mesh), axis) - selected_shards = tuple(shard for device_id, shard in zip(device_mesh, shards) if device_id == rank) + selected_shards = tuple(shard for device_id, shard in zip(device_mesh, shards, strict=False) if device_id == rank) return np.concatenate(selected_shards, axis=axis) @@ -99,12 +98,12 @@ def shard_tensor_per_spec(tensor: np.ndarray, rank: int, spec: str, device_mesh: class TestDistributedReshape(unittest.TestCase): def _check_distributed_reshape( self, - shape: Tuple[int, ...], - target_shape: Tuple[int, ...], + shape: tuple[int, ...], + target_shape: tuple[int, ...], input_device_meshes: np.ndarray, - input_shard_specs: Tuple[str, ...], + input_shard_specs: tuple[str, ...], output_device_meshes: np.ndarray, - output_shard_specs: Tuple[str, ...], + output_shard_specs: tuple[str, ...], ): input_device_mesh_shapes, input_device_mesh_elements = translate_all_device_meshes(input_device_meshes) output_device_mesh_shapes, output_device_mesh_elements = translate_all_device_meshes(output_device_meshes) @@ -683,12 +682,12 @@ def test_reshape_two_axis_fusion_shape_3_7_4096_rrs_01_shape_21_4906_rs_01(self) class TestDistributedExpand(unittest.TestCase): def _check_distributed_expand( self, - shape: Tuple[int, ...], - target_shape: Tuple[int, ...], + shape: tuple[int, ...], + target_shape: tuple[int, ...], input_device_meshes: np.ndarray, - input_shard_specs: Tuple[str, ...], + input_shard_specs: tuple[str, ...], output_device_meshes: np.ndarray, - output_shard_specs: Tuple[str, ...], + output_shard_specs: tuple[str, ...], ): assert len(input_device_meshes) == len(input_shard_specs) assert len(output_device_meshes) == len(output_shard_specs) @@ -855,12 +854,12 @@ def test_expand_in_tiny_llama(self): class TestDistributedUnsqueeze(unittest.TestCase): def _check_distributed_unsqueeze( self, - shape: Tuple[int, ...], - axes: Tuple[int, ...], + shape: tuple[int, ...], + axes: tuple[int, ...], input_device_meshes: np.ndarray, - input_shard_specs: Tuple[str, ...], + input_shard_specs: tuple[str, ...], output_device_meshes: np.ndarray, - output_shard_specs: Tuple[str, ...], + output_shard_specs: tuple[str, ...], ): assert len(input_device_meshes) == len(input_shard_specs) assert len(output_device_meshes) == len(output_shard_specs) @@ -977,12 +976,12 @@ def test_unsqueeze_not_sharded(self): class TestDistributedSqueeze(unittest.TestCase): def _check_distributed_squeeze( self, - shape: Tuple[int, ...], - axes: Tuple[int, ...], + shape: tuple[int, ...], + axes: tuple[int, ...], input_device_meshes: np.ndarray, - input_shard_specs: Tuple[str, ...], + input_shard_specs: tuple[str, ...], output_device_meshes: np.ndarray, - output_shard_specs: Tuple[str, ...], + output_shard_specs: tuple[str, ...], ): assert len(input_device_meshes) == len(input_shard_specs) assert len(output_device_meshes) == len(output_shard_specs) @@ -1086,12 +1085,12 @@ def _check_distributed_reduce( self, keepdims: int, dtype: np.dtype, - shape: Tuple[int, ...], - axes: Tuple[int, ...], + shape: tuple[int, ...], + axes: tuple[int, ...], input_device_meshes: np.ndarray, - input_shard_specs: Tuple[str, ...], + input_shard_specs: tuple[str, ...], output_device_meshes: np.ndarray, - output_shard_specs: Tuple[str, ...], + output_shard_specs: tuple[str, ...], ): assert len(input_device_meshes) == len(input_shard_specs) assert len(output_device_meshes) == len(output_shard_specs) @@ -1146,6 +1145,7 @@ def distributed_reduce_mean_instance(data_tensor: FLOAT, axes_tensor: INT64): for onnx_func, np_func in zip( [distributed_reduce_sum_instance, distributed_reduce_max_instance, distributed_reduce_mean_instance], [np.sum, np.maximum.reduce, np.mean], + strict=False, ): data = np.random.randint(4, size=shape).astype(dtype) expected = np_func(data, axis=axes, keepdims=bool(keepdims)) diff --git a/onnxruntime/test/python/onnxruntime_test_python_cudagraph.py b/onnxruntime/test/python/onnxruntime_test_python_cudagraph.py index ce04dff2aecb0..5ab2fe8939f6a 100644 --- a/onnxruntime/test/python/onnxruntime_test_python_cudagraph.py +++ b/onnxruntime/test/python/onnxruntime_test_python_cudagraph.py @@ -2,7 +2,6 @@ # Licensed under the MIT License. import unittest -from typing import Dict, List import numpy as np from helper import get_name @@ -14,7 +13,7 @@ class CudaGraphHelper: def __init__( self, ort_session: onnxrt.InferenceSession, - input_and_output_shape: Dict[str, List[int]], + input_and_output_shape: dict[str, list[int]], device_id: int = 0, ): self.input_names = [input.name for input in ort_session.get_inputs()] @@ -52,7 +51,7 @@ def get_io_numpy_type_map(self, ort_session: onnxrt.InferenceSession): return name_to_numpy_type - def update_inputs(self, inputs: Dict[str, np.ndarray]): + def update_inputs(self, inputs: dict[str, np.ndarray]): for input_name in self.input_names: self.io_ort_value[input_name].update_inplace(inputs[input_name]) diff --git a/onnxruntime/test/python/onnxruntime_test_python_dmlgraph.py b/onnxruntime/test/python/onnxruntime_test_python_dmlgraph.py index 29292c2a777b1..033eae1cb4c8d 100644 --- a/onnxruntime/test/python/onnxruntime_test_python_dmlgraph.py +++ b/onnxruntime/test/python/onnxruntime_test_python_dmlgraph.py @@ -2,7 +2,6 @@ # Licensed under the MIT License. import unittest -from typing import Dict, List import numpy as np from helper import get_name @@ -14,7 +13,7 @@ class DmlGraphHelper: def __init__( self, ort_session: onnxrt.InferenceSession, - input_and_output_shape: Dict[str, List[int]], + input_and_output_shape: dict[str, list[int]], device_id: int = 0, ): self.input_names = [input.name for input in ort_session.get_inputs()] @@ -52,7 +51,7 @@ def get_io_numpy_type_map(self, ort_session: onnxrt.InferenceSession): return name_to_numpy_type - def update_inputs(self, inputs: Dict[str, np.ndarray]): + def update_inputs(self, inputs: dict[str, np.ndarray]): for input_name in self.input_names: self.io_ort_value[input_name].update_inplace(inputs[input_name]) diff --git a/onnxruntime/test/python/onnxruntime_test_python_nested_control_flow_op.py b/onnxruntime/test/python/onnxruntime_test_python_nested_control_flow_op.py index bf354ad9f9e10..0a311245dd2b5 100644 --- a/onnxruntime/test/python/onnxruntime_test_python_nested_control_flow_op.py +++ b/onnxruntime/test/python/onnxruntime_test_python_nested_control_flow_op.py @@ -2,8 +2,8 @@ # Licensed under the MIT License. import unittest +from collections.abc import Sequence from copy import deepcopy -from typing import Optional, Sequence, Tuple import numpy as np from onnx import ModelProto, NodeProto, TensorProto, ValueInfoProto, checker, helper @@ -31,7 +31,7 @@ def make_optional_tensor_value_info(name: str, elem_type: int, shape: Sequence[i return vi -def make_optional_vi(vi: ValueInfoProto, name: Optional[str] = None) -> ValueInfoProto: +def make_optional_vi(vi: ValueInfoProto, name: str | None = None) -> ValueInfoProto: """Makes a copy of `vi` with optional type.""" name = name or vi.name + ".opt" vi_type = vi.type.tensor_type @@ -40,7 +40,7 @@ def make_optional_vi(vi: ValueInfoProto, name: Optional[str] = None) -> ValueInf return opt_vi -def make_const(vi: ValueInfoProto, name: str, value: int = 0) -> Tuple[ValueInfoProto, NodeProto, TensorProto]: +def make_const(vi: ValueInfoProto, name: str, value: int = 0) -> tuple[ValueInfoProto, NodeProto, TensorProto]: """Creates a constant 1D tensor from `vi`.""" const_vi = make_vi_like(vi, name) const_shape = [d.dim_value for d in vi.type.tensor_type.shape.dim] diff --git a/onnxruntime/test/python/onnxruntime_test_python_symbolic_shape_infer.py b/onnxruntime/test/python/onnxruntime_test_python_symbolic_shape_infer.py index 2f8fb84c4c651..92d6d758eef4d 100644 --- a/onnxruntime/test/python/onnxruntime_test_python_symbolic_shape_infer.py +++ b/onnxruntime/test/python/onnxruntime_test_python_symbolic_shape_infer.py @@ -120,7 +120,7 @@ def _check_shapes(self, graph, inferred_graph, vis): # type: (GraphProto, Graph vis_names = {x.name for x in vis} inferred_vis_names = {x.name for x in inferred_vis} assert vis_names == inferred_vis_names, (vis_names, inferred_vis_names) - for vi, inferred_vi in zip(vis, inferred_vis): + for vi, inferred_vi in zip(vis, inferred_vis, strict=False): assert vi == inferred_vi, f"\n{vi}\n{inferred_vi}\n" raise AssertionError() diff --git a/onnxruntime/test/python/onnxruntime_test_scatternd.py b/onnxruntime/test/python/onnxruntime_test_scatternd.py index e75c04dfb9965..42f706d1eec0f 100644 --- a/onnxruntime/test/python/onnxruntime_test_scatternd.py +++ b/onnxruntime/test/python/onnxruntime_test_scatternd.py @@ -19,7 +19,7 @@ def has_cuda(): return "CUDAExecutionProvider" in available_providers -def ignore_warnings(warns: typing.List[Warning]) -> typing.Callable: +def ignore_warnings(warns: list[Warning]) -> typing.Callable: def wrapper(fct): if warns is None: raise AssertionError(f"warns cannot be None for '{fct}'.") diff --git a/onnxruntime/test/python/quantization/test_calibration.py b/onnxruntime/test/python/quantization/test_calibration.py index 5856ec44bc85f..60c5f9d404258 100644 --- a/onnxruntime/test/python/quantization/test_calibration.py +++ b/onnxruntime/test/python/quantization/test_calibration.py @@ -358,9 +358,9 @@ def test_compute_data(self): rmin = np.minimum(rmin, np.amin(output, axis=1)) rmax = np.maximum(rmax, np.amax(output, axis=1)) - min_max_pairs = list(zip(rmin, rmax)) + min_max_pairs = list(zip(rmin, rmax, strict=False)) output_names = [infer_session.get_outputs()[i].name for i in range(len(infer_session.get_outputs()))] - output_min_max_dict = dict(zip(output_names, min_max_pairs)) + output_min_max_dict = dict(zip(output_names, min_max_pairs, strict=False)) for output_name, min_max in output_min_max_dict.items(): self.assertEqual(min_max, tensors_range[output_name].range_value) @@ -521,9 +521,9 @@ def test_compute_data_per_channel(self): rmin = np.minimum(rmin, np.amin(output, axis=-1)) rmax = np.maximum(rmax, np.amax(output, axis=-1)) - min_max_pairs = list(zip(rmin, rmax)) + min_max_pairs = list(zip(rmin, rmax, strict=False)) output_names = [infer_session.get_outputs()[i].name for i in range(len(infer_session.get_outputs()))] - output_min_max_dict = dict(zip(output_names, min_max_pairs)) + output_min_max_dict = dict(zip(output_names, min_max_pairs, strict=False)) for output_name, min_max in output_min_max_dict.items(): np.testing.assert_equal(min_max, tensors_range[output_name].range_value) diff --git a/onnxruntime/test/python/quantization/test_op_matmul_4bits.py b/onnxruntime/test/python/quantization/test_op_matmul_4bits.py index 292dc50124c16..ed0c65cba78ac 100644 --- a/onnxruntime/test/python/quantization/test_op_matmul_4bits.py +++ b/onnxruntime/test/python/quantization/test_op_matmul_4bits.py @@ -9,7 +9,6 @@ import unittest from importlib.util import find_spec from pathlib import Path -from typing import Dict, Tuple, Union import numpy as np import onnx @@ -28,7 +27,7 @@ def setUpClass(cls): def tearDownClass(cls): cls._tmp_model_dir.cleanup() - def fill_int4_data(self, shape: Union[int, Tuple[int, ...]], symmetric: bool) -> np.ndarray: + def fill_int4_data(self, shape: int | tuple[int, ...], symmetric: bool) -> np.ndarray: line = np.zeros(shape) line = line.reshape(-1) @@ -54,7 +53,7 @@ def fill_int4_data(self, shape: Union[int, Tuple[int, ...]], symmetric: bool) -> def input_feeds( self, n: int, - name2shape: Dict[str, Union[int, Tuple[int, ...]]], + name2shape: dict[str, int | tuple[int, ...]], low: int = -1, high: int = 2, dtype: type = np.float32, @@ -79,7 +78,7 @@ def construct_model_matmul(self, output_model_path: str, symmetric: bool) -> Non initializers = [] def make_matmul( - input_name, weight_shape: Union[int, Tuple[int, ...]], weight_name: str, output_name: str, node_name: str + input_name, weight_shape: int | tuple[int, ...], weight_name: str, output_name: str, node_name: str ): weight_data = self.fill_int4_data(weight_shape, symmetric).astype(np.float32) initializers.append(onnx.numpy_helper.from_array(weight_data, name=weight_name)) @@ -137,7 +136,7 @@ def construct_model_gather( initializers = [] def make_gather( - indices_name, data_shape: Union[int, Tuple[int, ...]], data_name: str, output_name: str, node_name: str + indices_name, data_shape: int | tuple[int, ...], data_name: str, output_name: str, node_name: str ): weight_data = self.fill_int4_data(data_shape, symmetric).astype( np.float32 if tdata == TensorProto.FLOAT else np.float16 @@ -184,8 +183,8 @@ def quant_test( block_size: int, is_symmetric: bool, quant_format: quant_utils.QuantFormat = quant_utils.QuantFormat.QOperator, - op_types_to_quantize: Tuple[str, ...] = ("MatMul",), - quant_axes: Tuple[Tuple[str, int], ...] = (("MatMul", 0), ("Gather", 1)), + op_types_to_quantize: tuple[str, ...] = ("MatMul",), + quant_axes: tuple[tuple[str, int], ...] = (("MatMul", 0), ("Gather", 1)), rtol: float = 0.01, atol: float = 0.05, ): diff --git a/onnxruntime/test/python/quantization/test_op_matmul_bnb4.py b/onnxruntime/test/python/quantization/test_op_matmul_bnb4.py index 88432d75c653e..d32abc1476600 100644 --- a/onnxruntime/test/python/quantization/test_op_matmul_bnb4.py +++ b/onnxruntime/test/python/quantization/test_op_matmul_bnb4.py @@ -9,7 +9,6 @@ import unittest from importlib.util import find_spec from pathlib import Path -from typing import Dict, Tuple, Union import numpy as np import onnx @@ -67,7 +66,7 @@ def setUpClass(cls): def tearDownClass(cls): cls._tmp_model_dir.cleanup() - def fill_bnb4_data(self, shape: Tuple[int, int], quant_type: int) -> np.ndarray: + def fill_bnb4_data(self, shape: tuple[int, int], quant_type: int) -> np.ndarray: rows, cols = shape line = np.zeros(shape) line = line.reshape(-1) @@ -84,7 +83,7 @@ def fill_bnb4_data(self, shape: Tuple[int, int], quant_type: int) -> np.ndarray: line = line.reshape(cols, rows).transpose() return line.reshape(shape) - def input_feeds(self, n: int, name2shape: Dict[str, Union[int, Tuple[int, ...]]]) -> TestDataFeeds: + def input_feeds(self, n: int, name2shape: dict[str, int | tuple[int, ...]]) -> TestDataFeeds: input_data_list = [] for _i in range(n): inputs = {} @@ -104,7 +103,7 @@ def construct_model_matmul(self, output_model_path: str, quant_type: int) -> Non output_name = "output" initializers = [] - def make_matmul(input_name, weight_shape: Union[int, Tuple[int, ...]], weight_name: str, output_name: str): + def make_matmul(input_name, weight_shape: int | tuple[int, ...], weight_name: str, output_name: str): weight_data = self.fill_bnb4_data(weight_shape, quant_type).astype(np.float32) initializers.append(onnx.numpy_helper.from_array(weight_data, name=weight_name)) return onnx.helper.make_node( diff --git a/onnxruntime/test/python/quantization/test_op_pad.py b/onnxruntime/test/python/quantization/test_op_pad.py index 755c7fae5e3e8..28dc8f4b7dee7 100644 --- a/onnxruntime/test/python/quantization/test_op_pad.py +++ b/onnxruntime/test/python/quantization/test_op_pad.py @@ -54,7 +54,7 @@ def construct_model_pad( input_tensor = helper.make_tensor_value_info("input", TensorProto.FLOAT, pad_input_shape) pad_dims_initializer = helper.make_tensor("pad_dims", TensorProto.INT64, [2 * rank], pad_dims) - output_shape = [sum(e) for e in list(zip(pad_input_shape, pad_dims[:rank], pad_dims[rank:]))] + output_shape = [sum(e) for e in list(zip(pad_input_shape, pad_dims[:rank], pad_dims[rank:], strict=False))] output_tensor = helper.make_tensor_value_info("output", TensorProto.FLOAT, output_shape) inputs = ["input", "pad_dims"] @@ -108,7 +108,7 @@ def construct_model_conv_pad( identity_node = helper.make_node("Identity", ["conv_output"], ["identity_out"], name="IdentityNode") pad_dims_initializer = helper.make_tensor("pad_dims", TensorProto.INT64, [2 * rank], pad_dims) - output_shape = [sum(e) for e in list(zip(pad_input_shape, pad_dims[:rank], pad_dims[rank:]))] + output_shape = [sum(e) for e in list(zip(pad_input_shape, pad_dims[:rank], pad_dims[rank:], strict=False))] output_tensor = helper.make_tensor_value_info("output", TensorProto.FLOAT, output_shape) pad_inputs = ["conv_output", "pad_dims"] initializers = [conv_weight_initializer, pad_dims_initializer] @@ -385,7 +385,7 @@ def construct_edge_case_model( identity_node = helper.make_node("Identity", ["conv_output"], ["identity_out"], name="IdentityNode") pad_dims_initializer = helper.make_tensor("pad_dims", TensorProto.INT64, [2 * rank], pad_dims) - output_shape = [sum(e) for e in list(zip(pad_input_shape, pad_dims[:rank], pad_dims[rank:]))] + output_shape = [sum(e) for e in list(zip(pad_input_shape, pad_dims[:rank], pad_dims[rank:], strict=False))] output_tensor = helper.make_tensor_value_info("output", TensorProto.FLOAT, output_shape) pad_inputs = ["conv_output", "pad_dims"] initializers = [conv_weight_initializer, pad_dims_initializer] diff --git a/onnxruntime/test/python/quantization/test_qdq_loss_debug.py b/onnxruntime/test/python/quantization/test_qdq_loss_debug.py index e9108f157f953..5d70641547eae 100644 --- a/onnxruntime/test/python/quantization/test_qdq_loss_debug.py +++ b/onnxruntime/test/python/quantization/test_qdq_loss_debug.py @@ -9,7 +9,6 @@ import tempfile import unittest from pathlib import Path -from typing import Dict, List import numpy as np import onnx @@ -108,7 +107,7 @@ def rewind(self): def augment_model_collect_activations( model_path: str, augmented_model_path: str, data_reader: TestDataReader -) -> Dict[str, List[np.ndarray]]: +) -> dict[str, list[np.ndarray]]: modify_model_output_intermediate_tensors(model_path, augmented_model_path) tensor_dict = collect_activations(augmented_model_path, data_reader) @@ -149,12 +148,12 @@ def test_saved_tensors_match_internal_tensors(self): output_dict = {} output_info = infer_session.get_outputs() for batch in oracle_outputs: - for output, output_data in zip(output_info, batch): + for output, output_data in zip(output_info, batch, strict=False): output_dict.setdefault(output.name, []).append(output_data) for output_name, model_outputs in output_dict.items(): test_outputs = tensor_dict[output_name] - for expected, actual in zip(model_outputs, test_outputs): + for expected, actual in zip(model_outputs, test_outputs, strict=False): exp = expected.reshape(-1) act = actual.reshape(-1) np.testing.assert_equal(exp, act) diff --git a/onnxruntime/test/python/quantization/test_tensor_quant_overrides_option.py b/onnxruntime/test/python/quantization/test_tensor_quant_overrides_option.py index 5617a424cf4dc..be10575b535e4 100644 --- a/onnxruntime/test/python/quantization/test_tensor_quant_overrides_option.py +++ b/onnxruntime/test/python/quantization/test_tensor_quant_overrides_option.py @@ -420,7 +420,7 @@ def test_qdq_overrides_per_channel2(self): ) self.assertEqual(wgt_zp.data_type, quant_type.tensor_type) - for index, (zp, scale) in enumerate(zip(wgt_zp.int32_data, wgt_sc.float_data)): + for index, (zp, scale) in enumerate(zip(wgt_zp.int32_data, wgt_sc.float_data, strict=False)): wgt_qmin, wgt_qmax = get_qmin_qmax_for_qType( wgt_zp.data_type, symmetric=True, # per-channel is always symmetric diff --git a/onnxruntime/test/python/test_pytorch_export_contrib_ops.py b/onnxruntime/test/python/test_pytorch_export_contrib_ops.py index 5e20d6b4e692a..96a9aaad3c331 100644 --- a/onnxruntime/test/python/test_pytorch_export_contrib_ops.py +++ b/onnxruntime/test/python/test_pytorch_export_contrib_ops.py @@ -43,7 +43,10 @@ def to_numpy(tensor): assert len(outputs) == len(ort_outs), "number of outputs differ" # compare onnxruntime and PyTorch results - [np.testing.assert_allclose(out, ort_out, rtol=rtol, atol=atol) for out, ort_out in zip(outputs, ort_outs)] + [ + np.testing.assert_allclose(out, ort_out, rtol=rtol, atol=atol) + for out, ort_out in zip(outputs, ort_outs, strict=False) + ] # These set of tests verify ONNX model export and compares outputs between diff --git a/onnxruntime/test/python/transformers/benchmark_gqa.py b/onnxruntime/test/python/transformers/benchmark_gqa.py index 5cef4ae863a0e..41dbdf255f35c 100644 --- a/onnxruntime/test/python/transformers/benchmark_gqa.py +++ b/onnxruntime/test/python/transformers/benchmark_gqa.py @@ -7,13 +7,11 @@ Benchmark performance of GroupQueryAttention. """ -from typing import Optional - import torch from test_sparse_attention import GroupQueryAttentionConfig, OrtGroupQueryAttention -def get_plot_algos(sm: int, local_window_size: Optional[int]): +def get_plot_algos(sm: int, local_window_size: int | None): # GQA with local windows only works in sm=8x if sm >= 80 and local_window_size: return { @@ -37,7 +35,7 @@ def plot_prompt_performance( kv_num_heads: int, head_size: int, max_seq_len: int, - local_window_size: Optional[int] = None, + local_window_size: int | None = None, use_smooth_softmax: bool = False, ): import triton @@ -70,7 +68,7 @@ def benchmark( num_heads: int, kv_num_heads: int, head_size: int, - local_window_size: Optional[int] = None, + local_window_size: int | None = None, use_smooth_softmax: bool = False, device="cuda", ): @@ -107,7 +105,7 @@ def plot_token_performance( kv_num_heads: int, head_size: int, max_seq_len: int, - local_window_size: Optional[int] = None, + local_window_size: int | None = None, use_smooth_softmax: bool = False, ): import triton @@ -140,7 +138,7 @@ def benchmark( num_heads: int, kv_num_heads: int, head_size: int, - local_window_size: Optional[int] = None, + local_window_size: int | None = None, use_smooth_softmax: bool = False, device="cuda", ): diff --git a/onnxruntime/test/python/transformers/benchmark_gqa_windows.py b/onnxruntime/test/python/transformers/benchmark_gqa_windows.py index 79cc8e41bf343..97ff8f4b21a68 100644 --- a/onnxruntime/test/python/transformers/benchmark_gqa_windows.py +++ b/onnxruntime/test/python/transformers/benchmark_gqa_windows.py @@ -1,7 +1,6 @@ import argparse import os import time -from typing import Optional import torch from test_sparse_attention import GroupQueryAttentionConfig, OrtGroupQueryAttention @@ -36,7 +35,7 @@ def benchmark( max_seq_len: int, sequence_length: int = 1, past_sequence_length: int = 0, - local_window_size: Optional[int] = None, + local_window_size: int | None = None, use_smooth_softmax: bool = False, model_name: str = "Llama3-8B", ): diff --git a/onnxruntime/test/python/transformers/benchmark_mha.py b/onnxruntime/test/python/transformers/benchmark_mha.py index d922f153b4b91..d5bcabe0bf147 100644 --- a/onnxruntime/test/python/transformers/benchmark_mha.py +++ b/onnxruntime/test/python/transformers/benchmark_mha.py @@ -23,10 +23,10 @@ import sys import threading import time +from collections.abc import Callable from contextlib import nullcontext from datetime import datetime from enum import IntEnum -from typing import Callable, Dict, List, Optional, Tuple import torch import torch.utils.benchmark as benchmark @@ -56,7 +56,7 @@ def convert(format_str: str) -> int: return names.index(format_str) @staticmethod - def get_name_list() -> List[str]: + def get_name_list() -> list[str]: return ["Q,K,V", "QKV", "Q,KV", "Q,K',V'"] @@ -95,7 +95,7 @@ def __init__( max_cache_sequence_length=None, scale: float = 0.0, provider="CPUExecutionProvider", - device: Optional[torch.device] = None, + device: torch.device | None = None, enable_cuda_graph: bool = False, dtype=torch.float, use_kv_cache: bool = False, @@ -205,7 +205,7 @@ def __repr__(self): ) def shape_dict(self, input_format=None): - shapes: Dict[str, Tuple] = { + shapes: dict[str, tuple] = { "output": (self.batch_size, self.sequence_length, self.num_heads * self.head_size), } @@ -272,7 +272,7 @@ def shape_dict(self, input_format=None): return shapes def symbolic_shape_dict(self, input_format=None): - shapes: Dict[str, Tuple] = { + shapes: dict[str, tuple] = { "output": ("batch_size", "sequence_length", self.num_heads * self.head_size), } @@ -346,7 +346,7 @@ def right_side_padding_masks(self): ) if self.mask_format != AttentionMaskFormat.Mask_None: - for i, (m, n) in enumerate(zip(self.mask_index_q, self.mask_index_kv)): + for i, (m, n) in enumerate(zip(self.mask_index_q, self.mask_index_kv, strict=False)): q_mask[i, :, m:, :] = False k_mask[i, :, n:, :] = False mask[i, :, m:, :] = False @@ -660,7 +660,7 @@ def run_torch_sdpa( has_mask: bool = False, mask_dim: int = 2, mask_dtype=torch.bool, - backend: Optional[int] = None, + backend: int | None = None, repeats: int = 100, ): q_shape = (batch_size, num_heads, q_seq_len, head_size) diff --git a/onnxruntime/test/python/transformers/bert_model_generator.py b/onnxruntime/test/python/transformers/bert_model_generator.py index a84137f092e64..0bb71bd8736d4 100644 --- a/onnxruntime/test/python/transformers/bert_model_generator.py +++ b/onnxruntime/test/python/transformers/bert_model_generator.py @@ -5,7 +5,6 @@ # -------------------------------------------------------------------------- import math -from typing import List import numpy as np import onnx @@ -13,7 +12,7 @@ from packaging import version -def float_tensor(name: str, shape: List[int], random=False): +def float_tensor(name: str, shape: list[int], random=False): low = 0.0 high = 1.0 total_elements = 1 diff --git a/onnxruntime/test/python/transformers/conformer_model_generator.py b/onnxruntime/test/python/transformers/conformer_model_generator.py index 71e4f2b63cf4f..4e76478bfb649 100644 --- a/onnxruntime/test/python/transformers/conformer_model_generator.py +++ b/onnxruntime/test/python/transformers/conformer_model_generator.py @@ -4,7 +4,6 @@ # license information. # -------------------------------------------------------------------------- -from typing import List import numpy as np import onnx @@ -13,7 +12,7 @@ # Adapted from bert_model_generator.py -def get_tensor_and_weight(name: str, shape: List[int], random=False, zeros=False): +def get_tensor_and_weight(name: str, shape: list[int], random=False, zeros=False): low = 0.0 high = 1.0 total_elements = 1 diff --git a/onnxruntime/test/python/transformers/gpt2_model_generator.py b/onnxruntime/test/python/transformers/gpt2_model_generator.py index 0865c87b70da7..74136c2b8bc61 100644 --- a/onnxruntime/test/python/transformers/gpt2_model_generator.py +++ b/onnxruntime/test/python/transformers/gpt2_model_generator.py @@ -5,7 +5,6 @@ # -------------------------------------------------------------------------- import math -from typing import List # noqa: F401 import numpy import onnx diff --git a/onnxruntime/test/python/transformers/rotary_flash.py b/onnxruntime/test/python/transformers/rotary_flash.py index 4329b2c1a6057..a033805ec0d5e 100644 --- a/onnxruntime/test/python/transformers/rotary_flash.py +++ b/onnxruntime/test/python/transformers/rotary_flash.py @@ -1,8 +1,6 @@ # Copyright (c) 2023, Tri Dao. -from typing import Optional, Tuple, Union - import torch import triton import triton.language as tl @@ -142,9 +140,9 @@ def apply_rotary( x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, - seqlen_offsets: Union[int, torch.Tensor] = 0, - cu_seqlens: Optional[torch.Tensor] = None, - max_seqlen: Optional[int] = None, + seqlen_offsets: int | torch.Tensor = 0, + cu_seqlens: torch.Tensor | None = None, + max_seqlen: int | None = None, interleaved=False, inplace=False, conjugate=False, @@ -265,9 +263,9 @@ def forward( sin, interleaved=False, inplace=False, - seqlen_offsets: Union[int, torch.Tensor] = 0, - cu_seqlens: Optional[torch.Tensor] = None, - max_seqlen: Optional[int] = None, + seqlen_offsets: int | torch.Tensor = 0, + cu_seqlens: torch.Tensor | None = None, + max_seqlen: int | None = None, ): out = apply_rotary( x, @@ -321,9 +319,9 @@ def apply_rotary_emb( sin, interleaved=False, inplace=False, - seqlen_offsets: Union[int, torch.Tensor] = 0, - cu_seqlens: Optional[torch.Tensor] = None, - max_seqlen: Optional[int] = None, + seqlen_offsets: int | torch.Tensor = 0, + cu_seqlens: torch.Tensor | None = None, + max_seqlen: int | None = None, ): """ Arguments: @@ -360,7 +358,7 @@ def forward( cos_k=None, sin_k=None, interleaved=False, - seqlen_offsets: Union[int, torch.Tensor] = 0, + seqlen_offsets: int | torch.Tensor = 0, ): batch, seqlen, three, nheads, headdim = qkv.shape assert three == 3 @@ -432,7 +430,7 @@ def apply_rotary_emb_qkv_( cos_k=None, sin_k=None, interleaved=False, - seqlen_offsets: Union[int, torch.Tensor] = 0, + seqlen_offsets: int | torch.Tensor = 0, ): """ Arguments: @@ -453,7 +451,7 @@ def apply_rotary_emb_qkv_( class ApplyRotaryEmbKV(torch.autograd.Function): @staticmethod - def forward(ctx, kv, cos, sin, interleaved=False, seqlen_offsets: Union[int, torch.Tensor] = 0): + def forward(ctx, kv, cos, sin, interleaved=False, seqlen_offsets: int | torch.Tensor = 0): batch, seqlen, two, nheads, headdim = kv.shape assert two == 2 k = kv[:, :, 0] @@ -491,7 +489,7 @@ def apply_rotary_emb_kv_( cos, sin, interleaved=False, - seqlen_offsets: Union[int, torch.Tensor] = 0, + seqlen_offsets: int | torch.Tensor = 0, ): """ Arguments: @@ -623,10 +621,10 @@ def _update_cos_sin_cache(self, seqlen, device=None, dtype=None): def forward( self, qkv: torch.Tensor, - kv: Optional[torch.Tensor] = None, - seqlen_offset: Union[int, torch.Tensor] = 0, - max_seqlen: Optional[int] = None, - ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + kv: torch.Tensor | None = None, + seqlen_offset: int | torch.Tensor = 0, + max_seqlen: int | None = None, + ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: """ qkv: (batch, seqlen, 3, nheads, headdim) if kv is none, else it's just q of shape (batch, seqlen, nheads, headdim) diff --git a/onnxruntime/test/python/transformers/test_gemmfastgelu_fusion.py b/onnxruntime/test/python/transformers/test_gemmfastgelu_fusion.py index 431ae21cd5eaf..c4c136981e7a9 100644 --- a/onnxruntime/test/python/transformers/test_gemmfastgelu_fusion.py +++ b/onnxruntime/test/python/transformers/test_gemmfastgelu_fusion.py @@ -6,7 +6,6 @@ import os import unittest -from typing import List import numpy as np import onnx @@ -33,7 +32,7 @@ opsets = [onnxdomain, msdomain] -def float_tensor(name: str, shape: List[int], random=False): +def float_tensor(name: str, shape: list[int], random=False): low = 0.0 high = 1.0 total_elements = 1 diff --git a/onnxruntime/test/python/transformers/test_group_norm.py b/onnxruntime/test/python/transformers/test_group_norm.py index bf295a65c8b53..7a04df8b39c0d 100644 --- a/onnxruntime/test/python/transformers/test_group_norm.py +++ b/onnxruntime/test/python/transformers/test_group_norm.py @@ -7,7 +7,6 @@ from dataclasses import dataclass from enum import Enum from time import perf_counter -from typing import Optional, Tuple import numpy import torch @@ -215,11 +214,11 @@ def group_norm_ort( src: torch.Tensor, gamma: torch.Tensor, beta: torch.Tensor, - skip: Optional[torch.Tensor], - bias: Optional[torch.Tensor], + skip: torch.Tensor | None, + bias: torch.Tensor | None, config: GroupNormConfig, measure_latency=False, -) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[float]]: +) -> tuple[torch.Tensor, torch.Tensor | None, float | None]: onnx_model_str = create_group_norm_graph(config) ort_session = InferenceSession(onnx_model_str, providers=["CUDAExecutionProvider"]) @@ -276,10 +275,10 @@ def group_norm_torch( src: torch.Tensor, gamma: torch.Tensor, beta: torch.Tensor, - skip: Optional[torch.Tensor], - bias: Optional[torch.Tensor], + skip: torch.Tensor | None, + bias: torch.Tensor | None, config: GroupNormConfig, -) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: +) -> tuple[torch.Tensor, torch.Tensor | None]: add_out = src if skip is not None: diff --git a/onnxruntime/test/python/transformers/test_mha.py b/onnxruntime/test/python/transformers/test_mha.py index 6f396f35f7146..dc19e3ec95243 100644 --- a/onnxruntime/test/python/transformers/test_mha.py +++ b/onnxruntime/test/python/transformers/test_mha.py @@ -11,7 +11,6 @@ import itertools import os import unittest -from typing import Dict, List, Optional import numpy import torch @@ -102,9 +101,9 @@ def attention_reference( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, - scale: Optional[float] = None, - attn_bias: Optional[torch.Tensor] = None, - mask: Optional[torch.Tensor] = None, + scale: float | None = None, + attn_bias: torch.Tensor | None = None, + mask: torch.Tensor | None = None, verbose: bool = False, ) -> torch.Tensor: """Reference implementation of SDPA @@ -171,14 +170,14 @@ def attention_reference( def mha_with_past_reference( config: MultiHeadAttentionConfig, - past_k: Optional[torch.Tensor], - past_v: Optional[torch.Tensor], + past_k: torch.Tensor | None, + past_v: torch.Tensor | None, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - scale: Optional[float] = None, - attn_bias: Optional[torch.Tensor] = None, - mask: Optional[torch.Tensor] = None, + scale: float | None = None, + attn_bias: torch.Tensor | None = None, + mask: torch.Tensor | None = None, ): assert config.kv_sequence_length == config.sequence_length assert config.use_kv_cache @@ -648,7 +647,7 @@ def parity_check_mha( def parity_check_mha_multi_threading( - test_inputs: List[Dict], + test_inputs: list[dict], rtol: float = 1e-3, atol: float = 1e-3, attention_kernel=SdpaKernel.DEFAULT, diff --git a/onnxruntime/test/python/transformers/test_parity_decoder_attention.py b/onnxruntime/test/python/transformers/test_parity_decoder_attention.py index e870e7f95fcee..8b4a68402f995 100644 --- a/onnxruntime/test/python/transformers/test_parity_decoder_attention.py +++ b/onnxruntime/test/python/transformers/test_parity_decoder_attention.py @@ -10,7 +10,6 @@ # license information. # ------------------------------------------------------------------------- -from typing import List, Optional, Tuple import numpy import torch @@ -118,7 +117,7 @@ def forward( self, query, key, - layer_state: Optional[List[Tensor]], + layer_state: list[Tensor] | None, encoder_decoder_attention: bool, use_past=torch.tensor(False), # noqa: B008 ): @@ -182,13 +181,13 @@ def forward( self, query, key: Tensor, - key_padding_mask: Optional[Tensor] = None, - layer_state: Optional[List[Tensor]] = None, - attn_mask: Optional[Tensor] = None, + key_padding_mask: Tensor | None = None, + layer_state: list[Tensor] | None = None, + attn_mask: Tensor | None = None, output_attentions: bool = False, use_past=torch.tensor(False), # noqa: B008 has_key_padding_mask: bool = False, - ) -> Tuple[Tensor, Optional[Tensor]]: + ) -> tuple[Tensor, Tensor | None]: """Input shape: Time(SeqLen) x Batch x Channel""" static_kv: bool = self.encoder_decoder_attention tgt_len, bsz, embed_dim = query.size() @@ -241,13 +240,13 @@ def ort_forward( self, query, key: Tensor, - key_padding_mask: Optional[Tensor] = None, - layer_state: Optional[List[Tensor]] = None, - attn_mask: Optional[Tensor] = None, + key_padding_mask: Tensor | None = None, + layer_state: list[Tensor] | None = None, + attn_mask: Tensor | None = None, output_attentions: bool = False, use_past=torch.tensor(False), # noqa: B008 has_key_padding_mask: bool = False, - ) -> Tuple[Tensor, Optional[Tensor]]: + ) -> tuple[Tensor, Tensor | None]: """Input shape: Time(SeqLen) x Batch x Channel""" # For readability static_kv = bool(self.encoder_decoder_attention) diff --git a/onnxruntime/test/python/transformers/test_rotary_embedding_fusion.py b/onnxruntime/test/python/transformers/test_rotary_embedding_fusion.py index 7bca48c29019e..89ef0342fab74 100644 --- a/onnxruntime/test/python/transformers/test_rotary_embedding_fusion.py +++ b/onnxruntime/test/python/transformers/test_rotary_embedding_fusion.py @@ -6,7 +6,6 @@ import os import sys import unittest -from typing import List import numpy as np import onnx @@ -23,7 +22,7 @@ from onnxruntime.transformers.optimizer import optimize_model -def float_tensor(name: str, shape: List[int], random=False): +def float_tensor(name: str, shape: list[int], random=False): low = 0.0 high = 1.0 total_elements = 1 @@ -113,7 +112,7 @@ def create_inputs_and_outputs(self, model_type: str = ""): outputs.append(helper.make_tensor_value_info("past_seq_len_plus_zero", TensorProto.FLOAT, [1])) return inputs, outputs - def create_fused_model(self, interleaved: bool, initializers: List[TensorProto]): + def create_fused_model(self, interleaved: bool, initializers: list[TensorProto]): inputs, outputs = self.create_inputs_and_outputs() rope_node = helper.make_node( @@ -385,7 +384,7 @@ def create_apply_rope_path(self): return x_half_shape_nodes + rotate_half_nodes + x_embed_nodes - def create_test_model(self, model_type: str, use_redundant_squeeze_ops: bool, initializers: List[TensorProto]): + def create_test_model(self, model_type: str, use_redundant_squeeze_ops: bool, initializers: list[TensorProto]): apply_rope_nodes = self.create_apply_rope_path() cache_nodes = self.create_cache_path(model_type, use_redundant_squeeze_ops) inputs, outputs = self.create_inputs_and_outputs(model_type) diff --git a/onnxruntime/test/python/transformers/test_rotary_mha_fusion.py b/onnxruntime/test/python/transformers/test_rotary_mha_fusion.py index aba0ccdac2e6e..0ec5c684532cc 100644 --- a/onnxruntime/test/python/transformers/test_rotary_mha_fusion.py +++ b/onnxruntime/test/python/transformers/test_rotary_mha_fusion.py @@ -6,7 +6,6 @@ import os import sys import unittest -from typing import List import numpy as np import onnx @@ -23,7 +22,7 @@ from onnxruntime.transformers.optimizer import optimize_model -def float_tensor(name: str, shape: List[int], random=False): +def float_tensor(name: str, shape: list[int], random=False): low = 0.0 high = 1.0 total_elements = 1 @@ -157,8 +156,8 @@ def create_rotary_embeddings( is_fused: bool, model_type: str, interleaved: bool, - inputs: List[TensorProto], - initializers: List[TensorProto], + inputs: list[TensorProto], + initializers: list[TensorProto], ): def get_first_rope_input(node_type: str): if is_fused or model_type == "llama2_msft": @@ -974,7 +973,7 @@ def create_qkv_path(self, model_type: str): return qkv_nodes + [transpose_qkv_node, reshape_qkv_2_node] # noqa: RUF005 - def create_concat_unsqueeze_paths(self, model_type: str, reshape_nodes: List[NodeProto]): + def create_concat_unsqueeze_paths(self, model_type: str, reshape_nodes: list[NodeProto]): # Create initial shape paths shape_0_node = helper.make_node( "Shape", @@ -1097,7 +1096,7 @@ def create_end_nodes(self, model_type): ) return [matmul_o_node, end_node] - def create_fused_model(self, model_type: str, interleaved: bool, initializers: List[TensorProto]): + def create_fused_model(self, model_type: str, interleaved: bool, initializers: list[TensorProto]): inputs, outputs = self.create_inputs_and_outputs(model_type) matmul_nodes = self.create_matmul_nodes(True, model_type=model_type) rope_nodes = self.create_rotary_embeddings(True, model_type, interleaved, inputs, initializers) @@ -1134,7 +1133,7 @@ def create_fused_model(self, model_type: str, interleaved: bool, initializers: L model = helper.make_model(graph, opset_imports=[opset_import]) return model - def create_test_model(self, model_type: str, interleaved: bool, initializers: List[TensorProto]): + def create_test_model(self, model_type: str, interleaved: bool, initializers: list[TensorProto]): inputs, outputs = self.create_inputs_and_outputs(model_type) matmul_nodes = self.create_matmul_nodes(False, model_type) rope_nodes = self.create_rotary_embeddings(False, model_type, interleaved, inputs, initializers) diff --git a/onnxruntime/test/python/transformers/test_simplified_layernorm_fusion.py b/onnxruntime/test/python/transformers/test_simplified_layernorm_fusion.py index e86bdda7baffb..95639958dbb2e 100644 --- a/onnxruntime/test/python/transformers/test_simplified_layernorm_fusion.py +++ b/onnxruntime/test/python/transformers/test_simplified_layernorm_fusion.py @@ -5,7 +5,6 @@ import os import unittest -from typing import List import numpy as np import onnx @@ -22,7 +21,7 @@ from onnxruntime.transformers.optimizer import optimize_model -def float_tensor(name: str, shape: List[int], random=False): +def float_tensor(name: str, shape: list[int], random=False): low = 0.0 high = 1.0 total_elements = 1 @@ -115,7 +114,7 @@ def create_inputs_and_outputs(self, start_node_type: str): ] return inputs, outputs, start_node - def create_fused_model(self, start_node_type: str, initializers: List[TensorProto]): + def create_fused_model(self, start_node_type: str, initializers: list[TensorProto]): inputs, outputs, start_node = self.create_inputs_and_outputs(start_node_type) sln_node = helper.make_node( @@ -139,7 +138,7 @@ def create_fused_model(self, start_node_type: str, initializers: List[TensorProt return model # Notation follows https://onnx.ai/onnx/operators/onnx__LayerNormalization.html#summary - def create_test_model(self, start_node_type: str, first_parent_idx: int, initializers: List[TensorProto]): + def create_test_model(self, start_node_type: str, first_parent_idx: int, initializers: list[TensorProto]): end_node = helper.make_node( "Mul", inputs=["scale", "Normalized"] if first_parent_idx == 1 else ["Normalized", "scale"], @@ -197,7 +196,7 @@ def create_test_model(self, start_node_type: str, first_parent_idx: int, initial model = helper.make_model(graph, opset_imports=[opset_import]) return model - def check_models(self, start_node_type: str, first_parent_idx: int, initializers: List[TensorProto]): + def check_models(self, start_node_type: str, first_parent_idx: int, initializers: list[TensorProto]): expected_model_filename = "expected_model.onnx" expected_model = self.create_fused_model(start_node_type, initializers) onnx.save(expected_model, expected_model_filename) diff --git a/onnxruntime/test/python/transformers/test_skip_layer_norm_fusion.py b/onnxruntime/test/python/transformers/test_skip_layer_norm_fusion.py index 5b3a3f18cd744..a55ff5aa91519 100644 --- a/onnxruntime/test/python/transformers/test_skip_layer_norm_fusion.py +++ b/onnxruntime/test/python/transformers/test_skip_layer_norm_fusion.py @@ -6,7 +6,6 @@ import os import unittest -from typing import Dict, List import numpy as np import onnx @@ -21,7 +20,7 @@ from onnxruntime.transformers.optimizer import optimize_model -def float_tensor(name: str, shape: List[int], random=False): +def float_tensor(name: str, shape: list[int], random=False): low = 0.0 high = 1.0 total_elements = 1 @@ -35,9 +34,9 @@ class TestFusion(unittest.TestCase): def verify_skip_layer_norm_fusion( self, model_path: str, - expected_counter: Dict[str, int], - expected_inputs: List[str], - expected_outputs: List[str], + expected_counter: dict[str, int], + expected_inputs: list[str], + expected_outputs: list[str], ): options = FusionOptions("bert") optimized_model = optimize_model(model_path, optimization_options=options, opt_level=0) diff --git a/onnxruntime/test/python/transformers/test_sparse_attention.py b/onnxruntime/test/python/transformers/test_sparse_attention.py index 774761afddc8a..eac6bbdc3dd12 100644 --- a/onnxruntime/test/python/transformers/test_sparse_attention.py +++ b/onnxruntime/test/python/transformers/test_sparse_attention.py @@ -9,7 +9,6 @@ import math import unittest -from typing import Optional, Union import torch from benchmark_mha import InputFormats @@ -34,7 +33,7 @@ def __init__( num_heads: int, kv_num_heads: int, head_size: int, - softmax_scale: Optional[float], + softmax_scale: float | None, do_rotary: bool, rotary_interleaved: bool, provider: str = "CUDAExecutionProvider", @@ -602,8 +601,8 @@ def group_query_attention_reference( key: Tensor, value: Tensor, config: GroupQueryAttentionConfig, - scale: Optional[float] = None, - mask: Optional[Tensor] = None, + scale: float | None = None, + mask: Tensor | None = None, ): if scale is None: scale = 1.0 / (config.head_size**0.5) @@ -704,7 +703,7 @@ def infer(self): def create_ort_session( - config: Union[SparseAttentionConfig, GroupQueryAttentionConfig], session_options=None, enable_cuda_graph=False + config: SparseAttentionConfig | GroupQueryAttentionConfig, session_options=None, enable_cuda_graph=False ) -> CudaSession: if isinstance(config, SparseAttentionConfig): onnx_model_str = create_sparse_attention_onnx_model(config) diff --git a/onnxruntime/test/python/transformers/whisper_model_generator.py b/onnxruntime/test/python/transformers/whisper_model_generator.py index 71d1a4cbdceeb..f1a692b7694cb 100644 --- a/onnxruntime/test/python/transformers/whisper_model_generator.py +++ b/onnxruntime/test/python/transformers/whisper_model_generator.py @@ -4,7 +4,6 @@ # license information. # -------------------------------------------------------------------------- -from typing import List import numpy as np import onnx @@ -13,7 +12,7 @@ # Adapted from bert_model_generator.py -def get_tensor_and_weight(name: str, shape: List[int], random=False, zeros=False): +def get_tensor_and_weight(name: str, shape: list[int], random=False, zeros=False): low = 0.0 high = 1.0 total_elements = 1 diff --git a/onnxruntime/test/testdata/CNTK/gen.py b/onnxruntime/test/testdata/CNTK/gen.py index 5a3ca461f471a..b5f39bcb448f9 100644 --- a/onnxruntime/test/testdata/CNTK/gen.py +++ b/onnxruntime/test/testdata/CNTK/gen.py @@ -23,7 +23,7 @@ def SaveTensorProto(file_path, variable, data, name): # noqa: N802 def SaveData(test_data_dir, prefix, variables, data_list, name_replacements=None): # noqa: N802 if isinstance(data_list, np.ndarray): data_list = [data_list] - for (i, d), v in zip(enumerate(data_list), variables): + for (i, d), v in zip(enumerate(data_list), variables, strict=False): SaveTensorProto( os.path.join(test_data_dir, f"{prefix}_{i}.pb"), v, diff --git a/onnxruntime/test/testdata/sparse_initializer_as_output.py b/onnxruntime/test/testdata/sparse_initializer_as_output.py index b10c84ccc1723..3a7e47910783e 100644 --- a/onnxruntime/test/testdata/sparse_initializer_as_output.py +++ b/onnxruntime/test/testdata/sparse_initializer_as_output.py @@ -2,7 +2,7 @@ import os # noqa: F401 import sys import traceback -from typing import Any, Callable, Dict, List, Optional, Sequence, Text, Tuple, TypeVar, Union, cast # noqa: F401 +from collections.abc import Callable, Sequence # noqa: F401 import numpy as np import onnx diff --git a/onnxruntime/test/testdata/sparse_to_dense_matmul.py b/onnxruntime/test/testdata/sparse_to_dense_matmul.py index 57a15ba72308e..bbc7f0bc0e88f 100644 --- a/onnxruntime/test/testdata/sparse_to_dense_matmul.py +++ b/onnxruntime/test/testdata/sparse_to_dense_matmul.py @@ -2,7 +2,7 @@ import os # noqa: F401 import sys import traceback -from typing import Any, Callable, Dict, List, Optional, Sequence, Text, Tuple, TypeVar, Union, cast # noqa: F401 +from collections.abc import Callable, Sequence # noqa: F401 import numpy as np # noqa: F401 import onnx diff --git a/orttraining/orttraining/python/training/artifacts.py b/orttraining/orttraining/python/training/artifacts.py index 31591c0156b14..c304d2f262650 100644 --- a/orttraining/orttraining/python/training/artifacts.py +++ b/orttraining/orttraining/python/training/artifacts.py @@ -6,7 +6,6 @@ import os import pathlib from enum import Enum -from typing import List, Optional, Union import onnx @@ -40,18 +39,18 @@ class OptimType(Enum): def generate_artifacts( - model: Union[onnx.ModelProto, str], - requires_grad: Optional[List[str]] = None, - frozen_params: Optional[List[str]] = None, - loss: Optional[Union[LossType, onnxblock.Block]] = None, - optimizer: Optional[Union[OptimType, onnxblock.Block]] = None, - artifact_directory: Optional[Union[str, bytes, os.PathLike]] = None, + model: onnx.ModelProto | str, + requires_grad: list[str] | None = None, + frozen_params: list[str] | None = None, + loss: LossType | onnxblock.Block | None = None, + optimizer: OptimType | onnxblock.Block | None = None, + artifact_directory: str | bytes | os.PathLike | None = None, prefix: str = "", ort_format: bool = False, - custom_op_library: Optional[Union[str, bytes, os.PathLike]] = None, - additional_output_names: Optional[List[str]] = None, + custom_op_library: str | bytes | os.PathLike | None = None, + additional_output_names: list[str] | None = None, nominal_checkpoint: bool = False, - loss_input_names: Optional[List[str]] = None, + loss_input_names: list[str] | None = None, ) -> None: """Generates artifacts required for training with ORT training api. diff --git a/orttraining/orttraining/python/training/experimental/gradient_graph/_gradient_graph_tools.py b/orttraining/orttraining/python/training/experimental/gradient_graph/_gradient_graph_tools.py index 5ab79b3712472..9ea12753a254b 100644 --- a/orttraining/orttraining/python/training/experimental/gradient_graph/_gradient_graph_tools.py +++ b/orttraining/orttraining/python/training/experimental/gradient_graph/_gradient_graph_tools.py @@ -1,6 +1,7 @@ import io +from collections.abc import Callable from pathlib import Path -from typing import Any, Callable, Optional, Union # noqa: F401 +from typing import Any import torch from torch.onnx import TrainingMode @@ -15,7 +16,7 @@ def export_gradient_graph( loss_fn: Callable[[Any, Any], Any], example_input: torch.Tensor, example_labels: torch.Tensor, - gradient_graph_path: Union[Path, str], + gradient_graph_path: Path | str, opset_version=12, ) -> None: r""" @@ -45,7 +46,7 @@ def export_gradient_graph( class WrapperModule(torch.nn.Module): def forward(self, model_input, expected_labels, *model_params): - for param, set_param in zip(model.parameters(), model_params): + for param, set_param in zip(model.parameters(), model_params, strict=False): param.data = set_param.data output = model(model_input) loss = loss_fn(output, expected_labels) diff --git a/orttraining/orttraining/python/training/onnxblock/_graph_utils.py b/orttraining/orttraining/python/training/onnxblock/_graph_utils.py index 42743a4200d17..fd10e6b65fb84 100644 --- a/orttraining/orttraining/python/training/onnxblock/_graph_utils.py +++ b/orttraining/orttraining/python/training/onnxblock/_graph_utils.py @@ -1,7 +1,6 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. -from typing import List, Union import onnx @@ -43,7 +42,7 @@ def generate_graph_name(token: str) -> str: return f"onnx::{token}::{_get_token()}" -def register_graph_outputs(model: onnx.ModelProto, output_names: Union[List[str], str]) -> None: +def register_graph_outputs(model: onnx.ModelProto, output_names: list[str] | str) -> None: """Register the given output names as graph outputs. The graph outputs shape information is extracted from the graph value_infos and diff --git a/orttraining/orttraining/python/training/onnxblock/_training_graph_utils.py b/orttraining/orttraining/python/training/onnxblock/_training_graph_utils.py index 1213342004d48..fbdbac3504b65 100644 --- a/orttraining/orttraining/python/training/onnxblock/_training_graph_utils.py +++ b/orttraining/orttraining/python/training/onnxblock/_training_graph_utils.py @@ -4,7 +4,6 @@ import copy import logging import os -from typing import List, Optional, Set, Tuple, Union import onnx @@ -35,7 +34,7 @@ def disable_training_mode_batchnorm(node): ops_to_disable_training_mode_func_map[node.op_type](node) -def _reorder_outputs(model: onnx.ModelProto, user_output_names: List[str], requires_grad: Set[str]) -> None: +def _reorder_outputs(model: onnx.ModelProto, user_output_names: list[str], requires_grad: set[str]) -> None: """Reorders the outputs of the model to match the order of [user_outputs, gradients]""" graph_outputs = {output.name: output for output in model.graph.output} @@ -50,7 +49,7 @@ def _reorder_outputs(model: onnx.ModelProto, user_output_names: List[str], requi model.graph.output.extend(ordered_graph_outputs) -def _move_initializers_to_inputs(model: onnx.ModelProto, initializer_names: Optional[Set[str]] = None) -> None: +def _move_initializers_to_inputs(model: onnx.ModelProto, initializer_names: set[str] | None = None) -> None: # Move all trainable and non trainable initializers to graph inputs. # This allows training to pass in the parameters from outside the graph # so as to share the parameters across multiple sessions. @@ -70,9 +69,9 @@ def _move_initializers_to_inputs(model: onnx.ModelProto, initializer_names: Opti def _gradient_model_for( model: onnx.ModelProto, - requires_grad: Set[str], + requires_grad: set[str], loss_name: str, - options: Optional[SessionOptions] = None, + options: SessionOptions | None = None, ) -> onnx.ModelProto: """Builds the gradient graph on top of the given input forward only graph.""" @@ -87,11 +86,11 @@ def _gradient_model_for( def build_gradient_graph( model: onnx.ModelProto, - requires_grad: Set[str], - frozen_params: Set[str], - output_names: Union[List[str], str], - custom_op_library: Optional[str] = None, -) -> Tuple[onnx.ModelProto, onnx.ModelProto]: + requires_grad: set[str], + frozen_params: set[str], + output_names: list[str] | str, + custom_op_library: str | None = None, +) -> tuple[onnx.ModelProto, onnx.ModelProto]: """Prepare the training model and the eval model. This function will restructure the model to prepare for training. @@ -134,7 +133,7 @@ def build_gradient_graph( return gradient_model, eval_model -def build_gradient_accumulation_graph(grad_model: onnx.ModelProto, requires_grad: Set[str]) -> None: +def build_gradient_accumulation_graph(grad_model: onnx.ModelProto, requires_grad: set[str]) -> None: """Builds gradient accumulation nodes on top of a training model. Adds an InPlaceAccumulatorV2 node for every gradient so that the gradients @@ -209,8 +208,8 @@ def build_gradient_accumulation_graph(grad_model: onnx.ModelProto, requires_grad def get_model_parameters( - model: onnx.ModelProto, requires_grad: Set[str], frozen_params: Set[str] -) -> Tuple[List[onnx.TensorProto], List[onnx.TensorProto]]: + model: onnx.ModelProto, requires_grad: set[str], frozen_params: set[str] +) -> tuple[list[onnx.TensorProto], list[onnx.TensorProto]]: """Returns trainable and non trainable onnx model parameters. This function pulls out the model parameters from the initializers in the graph. diff --git a/orttraining/orttraining/python/training/onnxblock/blocks.py b/orttraining/orttraining/python/training/onnxblock/blocks.py index c13843f816f16..24dc263eeb09b 100644 --- a/orttraining/orttraining/python/training/onnxblock/blocks.py +++ b/orttraining/orttraining/python/training/onnxblock/blocks.py @@ -6,7 +6,7 @@ import logging import os from abc import ABC, abstractmethod -from typing import Any, List, Optional +from typing import Any import numpy as np import onnx @@ -402,7 +402,7 @@ def __init__(self, like: str): self._like = like - def build(self, input_name: Optional[str] = None): + def build(self, input_name: str | None = None): cloned_input = None with contextlib.suppress(LookupError): # Suppress LookupError because we want to try to get the input from the output if it's not found in the inputs @@ -428,12 +428,12 @@ def __init__( default_float: float = 0.0, default_int64: int = -1, default_string: str = "_Unused", - keys_floats: Optional[List[float]] = None, - keys_int64s: Optional[List[int]] = None, - keys_strings: Optional[List[str]] = None, - values_floats: Optional[List[float]] = None, - values_int64s: Optional[List[int]] = None, - values_strings: Optional[List[str]] = None, + keys_floats: list[float] | None = None, + keys_int64s: list[int] | None = None, + keys_strings: list[str] | None = None, + values_floats: list[float] | None = None, + values_int64s: list[int] | None = None, + values_strings: list[str] | None = None, ): super().__init__() @@ -443,8 +443,8 @@ def __init__( "default_string": default_string, } - def _add_attributes(names: List[str], values: List[Any]): - for name, value in zip(names, values): + def _add_attributes(names: list[str], values: list[Any]): + for name, value in zip(names, values, strict=False): if value is not None: self._attributes[name] = value diff --git a/orttraining/orttraining/python/training/onnxblock/checkpoint_utils.py b/orttraining/orttraining/python/training/onnxblock/checkpoint_utils.py index de3453c630f9c..74292ea10a522 100644 --- a/orttraining/orttraining/python/training/onnxblock/checkpoint_utils.py +++ b/orttraining/orttraining/python/training/onnxblock/checkpoint_utils.py @@ -2,7 +2,6 @@ # Licensed under the MIT License. import os -from typing import List, Tuple, Union import onnx @@ -11,8 +10,8 @@ def save_checkpoint( - parameters: Tuple[List[onnx.TensorProto], List[onnx.TensorProto]], - path_to_checkpoint: Union[str, os.PathLike], + parameters: tuple[list[onnx.TensorProto], list[onnx.TensorProto]], + path_to_checkpoint: str | os.PathLike, nominal_checkpoint: bool = False, ) -> None: """Saves the parameters to the checkpoint directory path_to_checkpoint. @@ -32,7 +31,7 @@ def save_checkpoint( _save_checkpoint(trainable_params, non_trainable_params, os.fspath(path_to_checkpoint), nominal_checkpoint) -def load_checkpoint_to_model(path_to_checkpoint: Union[str, os.PathLike], model: onnx.ModelProto) -> None: +def load_checkpoint_to_model(path_to_checkpoint: str | os.PathLike, model: onnx.ModelProto) -> None: """Loads the checkpoint to an onnx inference model. Args: diff --git a/orttraining/orttraining/python/training/onnxblock/loss/loss.py b/orttraining/orttraining/python/training/onnxblock/loss/loss.py index 09429dd844187..e0624c6722519 100644 --- a/orttraining/orttraining/python/training/onnxblock/loss/loss.py +++ b/orttraining/orttraining/python/training/onnxblock/loss/loss.py @@ -2,7 +2,6 @@ # Licensed under the MIT License. import copy -from typing import Optional import onnx @@ -62,7 +61,7 @@ class CrossEntropyLoss(blocks.Block): contribute to the input gradient. """ - def __init__(self, weight=None, reduction: str = "mean", ignore_index: Optional[int] = None): + def __init__(self, weight=None, reduction: str = "mean", ignore_index: int | None = None): super().__init__() if reduction not in ["mean", "sum", "none"]: @@ -230,7 +229,7 @@ def __init__(self, reduction: str = "mean"): self._abs = blocks.Abs() self._sub = blocks.Sub() - def build(self, loss_input_name: str, target_name: Optional[str] = "target"): + def build(self, loss_input_name: str, target_name: str | None = "target"): """Adds an L1 loss subgraph on top of the base_model. Args: diff --git a/orttraining/orttraining/python/training/onnxblock/onnxblock.py b/orttraining/orttraining/python/training/onnxblock/onnxblock.py index 64f7acf4dc02c..0cb42cce9e5d5 100644 --- a/orttraining/orttraining/python/training/onnxblock/onnxblock.py +++ b/orttraining/orttraining/python/training/onnxblock/onnxblock.py @@ -3,7 +3,6 @@ import logging from abc import abstractmethod -from typing import List, Tuple import onnx @@ -139,7 +138,7 @@ def requires_grad(self, argument_name: str, value: bool = True): self._requires_grad.remove(argument_name) self._frozen_params.add(argument_name) - def parameters(self) -> Tuple[List[onnx.TensorProto], List[onnx.TensorProto]]: + def parameters(self) -> tuple[list[onnx.TensorProto], list[onnx.TensorProto]]: """Trainable as well as non-trainable (frozen) parameters of the model. Model parameters that are extracted while building the training model @@ -161,7 +160,7 @@ def parameters(self) -> Tuple[List[onnx.TensorProto], List[onnx.TensorProto]]: return self._parameters - def to_model_proto(self) -> Tuple[onnx.ModelProto, onnx.ModelProto]: + def to_model_proto(self) -> tuple[onnx.ModelProto, onnx.ModelProto]: """Returns the training and eval models. Once the gradient graph is built, the training and eval models can be retrieved diff --git a/orttraining/orttraining/python/training/onnxblock/optim/optim.py b/orttraining/orttraining/python/training/onnxblock/optim/optim.py index d14b2efefe916..a18fe7e6414e2 100644 --- a/orttraining/orttraining/python/training/onnxblock/optim/optim.py +++ b/orttraining/orttraining/python/training/onnxblock/optim/optim.py @@ -1,7 +1,6 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. -from typing import Dict, List, Optional, Tuple import onnx @@ -66,10 +65,10 @@ def __init__(self): def _build_optimizer_node( self, - input_names: List[str], + input_names: list[str], output_name: str, node_name: str, - node_attributes: Dict, + node_attributes: dict, ) -> str: """ Build and append an optimizer node to the ONNX graph. @@ -135,10 +134,10 @@ def build( class AdamWOptimizer(_OptimizerBase): def __init__( self, - bias_correction: Optional[bool] = True, - betas: Tuple[float, float] = (0.9, 0.999), - eps: Optional[float] = 1e-6, - weight_decay: Optional[float] = 0.0, + bias_correction: bool | None = True, + betas: tuple[float, float] = (0.9, 0.999), + eps: float | None = 1e-6, + weight_decay: float | None = 0.0, ): super().__init__() @@ -242,7 +241,7 @@ def _optimizer_specific_logic( learning_rate_name: str, params_name: str, gradients_name: str, - trainable_parameters: Tuple[List[onnx.TensorProto], List[onnx.TensorProto]], + trainable_parameters: tuple[list[onnx.TensorProto], list[onnx.TensorProto]], ) -> str: raise NotImplementedError("Subclasses must implement _optimizer_specific_logic method.") @@ -264,7 +263,7 @@ def _optimizer_specific_logic( learning_rate_name: str, params_name: str, gradients_name: str, - trainable_parameters: Tuple[List[onnx.TensorProto], List[onnx.TensorProto]], + trainable_parameters: tuple[list[onnx.TensorProto], list[onnx.TensorProto]], ) -> str: onnx_model = self.base step_name = "step" @@ -307,7 +306,7 @@ def _optimizer_specific_logic( learning_rate_name: str, params_name: str, gradients_name: str, - trainable_parameters: Tuple[List[onnx.TensorProto], List[onnx.TensorProto]], + trainable_parameters: tuple[list[onnx.TensorProto], list[onnx.TensorProto]], ) -> str: onnx_model = self.base updated_flag_name = self._sgd(learning_rate_name, params_name, gradients_name) diff --git a/orttraining/orttraining/python/training/optim/_ds_modifier.py b/orttraining/orttraining/python/training/optim/_ds_modifier.py index 55e2e08432137..9d8f178c1c65c 100644 --- a/orttraining/orttraining/python/training/optim/_ds_modifier.py +++ b/orttraining/orttraining/python/training/optim/_ds_modifier.py @@ -178,7 +178,7 @@ def is_model_parallel_parameter(p): #### THIS IS THE FASTER IMPLEMENTATION #### grads_for_norm = [] - for g, p in zip(gradients, params): + for g, p in zip(gradients, params, strict=False): if is_model_parallel_parameter(p) or (target.model_parallel_rank == 0): # BE NOTED: deepspeed original give a double type conversion here, not sure whether this is impacting some models. # https://github.com/microsoft/DeepSpeed/blob/9e5c0c5c3ecabb68b7e9dffac0e9b8d167e3cab8/deepspeed/runtime/zero/stage2.py#L1501 diff --git a/orttraining/orttraining/python/training/ort_triton/_cache.py b/orttraining/orttraining/python/training/ort_triton/_cache.py index b70064377abfc..294c844bb5ac5 100644 --- a/orttraining/orttraining/python/training/ort_triton/_cache.py +++ b/orttraining/orttraining/python/training/ort_triton/_cache.py @@ -12,7 +12,6 @@ import sys import tempfile from types import ModuleType -from typing import Tuple @functools.lru_cache(None) @@ -73,7 +72,7 @@ class ModuleCache: clear = staticmethod(cache.clear) @classmethod - def load(cls, key_func, mod_func, *args) -> Tuple[str, ModuleType]: + def load(cls, key_func, mod_func, *args) -> tuple[str, ModuleType]: key = key_func(*args) if key not in cls.cache: func_name, mod = mod_func(*args) diff --git a/orttraining/orttraining/python/training/ort_triton/_codegen.py b/orttraining/orttraining/python/training/ort_triton/_codegen.py index c6759630b2777..548b415ea990e 100644 --- a/orttraining/orttraining/python/training/ort_triton/_codegen.py +++ b/orttraining/orttraining/python/training/ort_triton/_codegen.py @@ -12,8 +12,6 @@ """ -from typing import Tuple - import sympy import torch from sympy.codegen.rewriting import create_expand_pow_optimization @@ -49,7 +47,7 @@ def codegen(self, node: IRNode, context: CodegenContext, code_buffer: CodeBuffer assert func is not None, f"unimplemented node: {node.__class__.__name__}" func(node, context, code_buffer, indent) - def _get_elementwise_offset_mask(self, offset_calc: OffsetCalculator, arg_name: str) -> Tuple[str, str]: + def _get_elementwise_offset_mask(self, offset_calc: OffsetCalculator, arg_name: str) -> tuple[str, str]: if offset_calc.is_x_reduced(arg_name): # Scalar. return "tl.full([1], 0, tl.int32)", "" @@ -61,7 +59,7 @@ def _get_elementwise_offset_mask(self, offset_calc: OffsetCalculator, arg_name: offset_str = str(expand_opt(sympy_dot(parse_shape(idx_var), strides))) return offset_str, "xmask" if offset_calc.requires_x_mask else "" - def _get_reduce_offset_mask(self, offset_calc: OffsetCalculator, arg_name: str) -> Tuple[str, str]: + def _get_reduce_offset_mask(self, offset_calc: OffsetCalculator, arg_name: str) -> tuple[str, str]: offset_strs = [] mask_strs = [] if not offset_calc.is_x_reduced(arg_name): @@ -93,7 +91,7 @@ def _get_reduce_offset_mask(self, offset_calc: OffsetCalculator, arg_name: str) offset_strs.append("tl.full([1, 1], 0, tl.int32)") return " + ".join(offset_strs), " & ".join(mask_strs) - def _get_offset_mask(self, offset_calc: OffsetCalculator, arg_name: str) -> Tuple[str, str]: + def _get_offset_mask(self, offset_calc: OffsetCalculator, arg_name: str) -> tuple[str, str]: return ( self._get_reduce_offset_mask(offset_calc, arg_name) if offset_calc.is_reduction diff --git a/orttraining/orttraining/python/training/ort_triton/_common.py b/orttraining/orttraining/python/training/ort_triton/_common.py index a1c3d7d7e1d4f..420c02f4c4385 100644 --- a/orttraining/orttraining/python/training/ort_triton/_common.py +++ b/orttraining/orttraining/python/training/ort_triton/_common.py @@ -4,7 +4,7 @@ # -------------------------------------------------------------------------- from abc import abstractmethod -from typing import Any, Dict, List, Tuple +from typing import Any import sympy from onnx import GraphProto, NodeProto, TensorProto @@ -12,7 +12,7 @@ from ._sympy_utils import extract_shape_from_symbol from ._utils import get_attribute, get_reduce_info, next_power_of_2 -_SPECIAL_FLOATS: List[str] = ["inf", "-inf"] +_SPECIAL_FLOATS: list[str] = ["inf", "-inf"] class CodegenContext: @@ -20,8 +20,8 @@ class CodegenContext: record variable name mapping in term of IRnodes. """ - def __init__(self, var_map: Dict[str, str]): - self._var_map: Dict[str, str] = {**var_map} + def __init__(self, var_map: dict[str, str]): + self._var_map: dict[str, str] = {**var_map} # Get variable name by the node arg name in ONNX graph. def get_variable_name(self, name: str) -> str: @@ -36,7 +36,7 @@ def get_internal_variable_name(self, name: str) -> str: class CodeBuffer: def __init__(self): - self.buffer: List[str] = [] + self.buffer: list[str] = [] def __iadd__(self, other: str): self.buffer.append(other) @@ -59,7 +59,7 @@ class SymbolicDSU: """ def __init__(self): - self._dsu: Dict[sympy.Expr, sympy.Expr] = {} + self._dsu: dict[sympy.Expr, sympy.Expr] = {} def find(self, symbolic: sympy.Expr) -> sympy.Expr: if symbolic not in self._dsu: @@ -81,25 +81,25 @@ class TensorInfo: Represent a input/output tensor of a node. """ - def __init__(self, dtype: TensorProto.DataType, shape: List[sympy.Expr]): + def __init__(self, dtype: TensorProto.DataType, shape: list[sympy.Expr]): self._dtype: TensorProto.DataType = dtype - self._shape: List[sympy.Expr] = shape + self._shape: list[sympy.Expr] = shape @property def dtype(self) -> TensorProto.DataType: return self._dtype @property - def shape(self) -> List[sympy.Expr]: + def shape(self) -> list[sympy.Expr]: return self._shape def update_shape(self, symbolics: SymbolicDSU): self._shape = [symbolics.find(dim) if dim.is_symbol else dim for dim in self._shape] -def _infer_elementwise_shape(input_infos: List[TensorInfo], symbolics: SymbolicDSU) -> List[sympy.Expr]: +def _infer_elementwise_shape(input_infos: list[TensorInfo], symbolics: SymbolicDSU) -> list[sympy.Expr]: max_len = max([len(input_info.shape) for input_info in input_infos]) - output_shape: List[sympy.Expr] = [sympy.Integer(1)] * max_len + output_shape: list[sympy.Expr] = [sympy.Integer(1)] * max_len for input_info in input_infos: offset = max_len - len(input_info.shape) for idx, dim in enumerate(input_info.shape): @@ -112,22 +112,22 @@ def _infer_elementwise_shape(input_infos: List[TensorInfo], symbolics: SymbolicD def _infer_elementwise( - node: NodeProto, input_infos: List[TensorInfo], graph: GraphProto, symbolics: SymbolicDSU -) -> List[TensorInfo]: + node: NodeProto, input_infos: list[TensorInfo], graph: GraphProto, symbolics: SymbolicDSU +) -> list[TensorInfo]: # pylint: disable=unused-argument return [TensorInfo(input_infos[0].dtype, _infer_elementwise_shape(input_infos, symbolics))] def _infer_where( - node: NodeProto, input_infos: List[TensorInfo], graph: GraphProto, symbolics: SymbolicDSU -) -> List[TensorInfo]: + node: NodeProto, input_infos: list[TensorInfo], graph: GraphProto, symbolics: SymbolicDSU +) -> list[TensorInfo]: # pylint: disable=unused-argument return [TensorInfo(input_infos[1].dtype, _infer_elementwise_shape(input_infos, symbolics))] def _infer_reduction( - node: NodeProto, input_infos: List[TensorInfo], graph: GraphProto, symbolics: SymbolicDSU -) -> List[TensorInfo]: + node: NodeProto, input_infos: list[TensorInfo], graph: GraphProto, symbolics: SymbolicDSU +) -> list[TensorInfo]: # pylint: disable=unused-argument input_rank = len(input_infos[0].shape) keep_dims, axes = get_reduce_info(node, graph, input_rank) @@ -141,15 +141,15 @@ def _infer_reduction( def _infer_unary( - node: NodeProto, input_infos: List[TensorInfo], graph: GraphProto, symbolics: SymbolicDSU -) -> List[TensorInfo]: + node: NodeProto, input_infos: list[TensorInfo], graph: GraphProto, symbolics: SymbolicDSU +) -> list[TensorInfo]: # pylint: disable=unused-argument return [input_infos[0]] def _infer_cast( - node: NodeProto, input_infos: List[TensorInfo], graph: GraphProto, symbolics: SymbolicDSU -) -> List[TensorInfo]: + node: NodeProto, input_infos: list[TensorInfo], graph: GraphProto, symbolics: SymbolicDSU +) -> list[TensorInfo]: # pylint: disable=unused-argument dtype = get_attribute(node, "to", TensorProto.UNDEFINED) assert dtype != TensorProto.UNDEFINED @@ -157,8 +157,8 @@ def _infer_cast( def _infer_dropout( - node: NodeProto, input_infos: List[TensorInfo], graph: GraphProto, symbolics: SymbolicDSU -) -> List[TensorInfo]: + node: NodeProto, input_infos: list[TensorInfo], graph: GraphProto, symbolics: SymbolicDSU +) -> list[TensorInfo]: # pylint: disable=unused-argument return [input_infos[0], TensorInfo(TensorProto.BOOL, input_infos[0].shape)] @@ -190,8 +190,8 @@ class TypeAndShapeInfer: @classmethod def infer( - cls, node: NodeProto, input_infos: List[TensorInfo], graph: GraphProto, symbolics: SymbolicDSU - ) -> List[TensorInfo]: + cls, node: NodeProto, input_infos: list[TensorInfo], graph: GraphProto, symbolics: SymbolicDSU + ) -> list[TensorInfo]: if node.op_type not in cls._INFER_FUNC_MAP: raise NotImplementedError(f"Unsupported op type: {node.op_type}") return cls._INFER_FUNC_MAP[node.op_type](node, input_infos, graph, symbolics) @@ -224,7 +224,7 @@ def __init__(self, x_numel: sympy.Expr, r_numel: sympy.Expr, contiguous: bool): ) ) ) - self.configs: List[Tuple[int, int, int]] = self._gen_autotune_configs(x_numel_int, r_numel_int, contiguous) + self.configs: list[tuple[int, int, int]] = self._gen_autotune_configs(x_numel_int, r_numel_int, contiguous) # If there is symbolic shape, we will not tune the kernel. if not x_numel.is_number or not r_numel.is_number: self.configs = self.configs[-1:] @@ -233,13 +233,13 @@ def __init__(self, x_numel: sympy.Expr, r_numel: sympy.Expr, contiguous: bool): def _num_warps(self, x: int, r: int) -> int: return min(max(x * r // 256, 2), 8) - def _gen_config(self, xnp2: int, rnp2: int, x: int, r: int) -> Tuple[int, int, int]: + def _gen_config(self, xnp2: int, rnp2: int, x: int, r: int) -> tuple[int, int, int]: x = min(x, xnp2) r = min(r, rnp2) return x, r, self._num_warps(x, r) # TODO: we need to tune more kernels to get more reasonable configs for better performance. - def _gen_autotune_configs(self, x_numel: int, r_numel: int, contiguous: bool) -> List[Tuple[int, int, int]]: + def _gen_autotune_configs(self, x_numel: int, r_numel: int, contiguous: bool) -> list[tuple[int, int, int]]: configs = [] xnp2 = next_power_of_2(x_numel) if r_numel == 1: diff --git a/orttraining/orttraining/python/training/ort_triton/_decompose.py b/orttraining/orttraining/python/training/ort_triton/_decompose.py index c1ded3975d3a6..601ab03847e72 100644 --- a/orttraining/orttraining/python/training/ort_triton/_decompose.py +++ b/orttraining/orttraining/python/training/ort_triton/_decompose.py @@ -8,8 +8,6 @@ "simple ops" can be executed in one pass """ -from typing import List - import sympy from onnx import GraphProto, NodeProto, TensorProto, helper @@ -30,7 +28,7 @@ class DecomposeDispatch: def __init__(self): self.count = 0 - def __call__(self, node: NodeProto, graph: GraphProto, **kwargs) -> List[NodeProto]: + def __call__(self, node: NodeProto, graph: GraphProto, **kwargs) -> list[NodeProto]: op_type = node.op_type if not hasattr(self, op_type): raise NotImplementedError(f"Not implemented for op type: {op_type}") diff --git a/orttraining/orttraining/python/training/ort_triton/_ir.py b/orttraining/orttraining/python/training/ort_triton/_ir.py index 23abb082c2217..f43e424493b2c 100644 --- a/orttraining/orttraining/python/training/ort_triton/_ir.py +++ b/orttraining/orttraining/python/training/ort_triton/_ir.py @@ -5,7 +5,7 @@ from abc import abstractmethod from collections import defaultdict -from typing import Any, Dict, List, Optional, Set, Tuple +from typing import Any import sympy import torch @@ -22,16 +22,16 @@ class TensorArg: If it's constant (initializer or constant node), it also contains the data in numpy array. """ - def __init__(self, name: str, tensor_info: Optional[TensorInfo] = None, data: Optional[torch.Tensor] = None): + def __init__(self, name: str, tensor_info: TensorInfo | None = None, data: torch.Tensor | None = None): self._name: str = name - self._data: Optional[torch.Tensor] = data + self._data: torch.Tensor | None = data if data is not None: self._dtype: torch.dtype = data.dtype - self._shape: List[sympy.Expr] = parse_shape(list(data.shape)) + self._shape: list[sympy.Expr] = parse_shape(list(data.shape)) else: assert tensor_info is not None self._dtype: torch.dtype = to_torch_dtype(tensor_info.dtype) - self._shape: List[sympy.Expr] = tensor_info.shape + self._shape: list[sympy.Expr] = tensor_info.shape self.cross_kernels: bool = False @property @@ -43,11 +43,11 @@ def dtype(self) -> torch.dtype: return self._dtype @property - def shape(self) -> List[sympy.Expr]: + def shape(self) -> list[sympy.Expr]: return self._shape @property - def data(self) -> Optional[torch.Tensor]: + def data(self) -> torch.Tensor | None: return self._data @@ -61,18 +61,18 @@ class OffsetCalculator: If a reduce node has non-contiguous axes, need to decompose it into multiple reduce nodes before code-gen. """ - def __init__(self, target_shape: List[sympy.Expr], reduce_axes: List[int]): - self.target_shape: List[sympy.Expr] = target_shape + def __init__(self, target_shape: list[sympy.Expr], reduce_axes: list[int]): + self.target_shape: list[sympy.Expr] = target_shape self.is_reduction: bool = len(reduce_axes) > 0 self.rank = len(target_shape) self.reduce_axes = sort_reduce_axes(reduce_axes, self.rank) - self.x_dims: List[sympy.Expr] = [target_shape[dim] for dim in range(self.rank) if dim not in self.reduce_axes] + self.x_dims: list[sympy.Expr] = [target_shape[dim] for dim in range(self.rank) if dim not in self.reduce_axes] self.x_rank: int = len(self.x_dims) self.x_numel: sympy.Expr = sympy.prod(self.x_dims) if self.x_rank > 0 else sympy.Integer(1) - self.r_dims: List[sympy.Expr] = [target_shape[dim] for dim in self.reduce_axes] + self.r_dims: list[sympy.Expr] = [target_shape[dim] for dim in self.reduce_axes] self.r_rank: int = len(self.r_dims) self.r_numel: sympy.Expr = sympy.prod(self.r_dims) if self.r_rank > 0 else sympy.Integer(1) - self.x_strides: List[sympy.Expr] = [] + self.x_strides: list[sympy.Expr] = [] if self.x_rank > 0: self.x_strides.append(sympy.Integer(1)) for i in range(self.x_rank - 2, -1, -1): @@ -80,14 +80,14 @@ def __init__(self, target_shape: List[sympy.Expr], reduce_axes: List[int]): # To avoid generating useless code for offset calculation, we use x_compute_dims and r_compute_dims to # track the dimensions that need to be computed in the offset calculation. These 2 sets will be set in # register_tensor_arg function below. - self.x_compute_dims: Set[int] = set() - self.r_strides: List[sympy.Expr] = [] + self.x_compute_dims: set[int] = set() + self.r_strides: list[sympy.Expr] = [] if self.r_rank > 0: self.r_strides.append(sympy.Integer(1)) for i in range(self.r_rank - 2, -1, -1): self.r_strides.insert(0, self.r_strides[0] * self.r_dims[i + 1]) - self.r_compute_dims: Set[int] = set() - self.input_strides: Dict[str, List[sympy.Expr]] = dict() + self.r_compute_dims: set[int] = set() + self.input_strides: dict[str, list[sympy.Expr]] = dict() self.autotune_configs: AutotuneConfigs = AutotuneConfigs( self.x_numel, self.r_numel, not self.is_reduction or self.reduce_axes[-1] == self.rank - 1 ) @@ -99,17 +99,17 @@ def __init__(self, target_shape: List[sympy.Expr], reduce_axes: List[int]): self.requires_r_mask: bool = any( simplified_r_numel % sympy.Integer(config[1]) != 0 for config in self.autotune_configs.configs ) - self.reduced_args: Set[str] = set() - self.symbolic_shape_variables: Set[str] = set() + self.reduced_args: set[str] = set() + self.symbolic_shape_variables: set[str] = set() - def get_input_strides(self, name: str) -> List[sympy.Expr]: + def get_input_strides(self, name: str) -> list[sympy.Expr]: assert name in self.input_strides return self.input_strides[name] - def get_x_input_strides(self, name: str) -> List[sympy.Expr]: + def get_x_input_strides(self, name: str) -> list[sympy.Expr]: return [dim for idx, dim in enumerate(self.get_input_strides(name)) if idx not in self.reduce_axes] - def get_r_input_strides(self, name: str) -> List[sympy.Expr]: + def get_r_input_strides(self, name: str) -> list[sympy.Expr]: return [dim for idx, dim in enumerate(self.get_input_strides(name)) if idx in self.reduce_axes] # Whether the x shape of the tensor argument is contiguous and is same as the target shape. @@ -195,9 +195,9 @@ class IRNode: The base class for all IR nodes. """ - def __init__(self, inputs: List[TensorArg], outputs: List[TensorArg]): - self.inputs: List[TensorArg] = inputs - self.outputs: List[TensorArg] = outputs + def __init__(self, inputs: list[TensorArg], outputs: list[TensorArg]): + self.inputs: list[TensorArg] = inputs + self.outputs: list[TensorArg] = outputs @abstractmethod def codegen(self, visitor: NodeVisitor, context: CodegenContext, code_buffer: CodeBuffer, indent: int = 0): @@ -212,13 +212,13 @@ class ComputeNode(IRNode): def __init__( self, op_type: str, - inputs: List[TensorArg], - outputs: List[TensorArg], - attributes: Dict[str, Any] = {}, # noqa: B006 + inputs: list[TensorArg], + outputs: list[TensorArg], + attributes: dict[str, Any] = {}, # noqa: B006 ): super().__init__(inputs, outputs) self._op_type: str = op_type - self._attributes: Dict[str, Any] = attributes + self._attributes: dict[str, Any] = attributes @property def op_type(self): @@ -230,7 +230,7 @@ def attributes(self): class ReduceNode(ComputeNode): - def __init__(self, op_type: str, inputs: List[TensorArg], outputs: List[TensorArg], offset_calc: OffsetCalculator): + def __init__(self, op_type: str, inputs: list[TensorArg], outputs: list[TensorArg], offset_calc: OffsetCalculator): super().__init__(op_type, inputs, outputs) assert op_type == "ReduceSum" or op_type == "ReduceMax" or op_type == "ReduceMin" self.default_value: str = ( @@ -250,9 +250,9 @@ class ReduceForLoopStart(ComputeNode): shared-memory declaration """ - def __init__(self, reduce_nodes: List[ReduceNode], offset_calc: OffsetCalculator): + def __init__(self, reduce_nodes: list[ReduceNode], offset_calc: OffsetCalculator): super().__init__("", [], []) - self.reduce_nodes: List[ReduceNode] = reduce_nodes + self.reduce_nodes: list[ReduceNode] = reduce_nodes self.offset_calc: OffsetCalculator = offset_calc @@ -261,9 +261,9 @@ class ReduceForLoopEnd(ComputeNode): shared-memory reduction """ - def __init__(self, reduce_nodes: List[ReduceNode], offset_calc: OffsetCalculator): + def __init__(self, reduce_nodes: list[ReduceNode], offset_calc: OffsetCalculator): super().__init__("", [], []) - self.reduce_nodes: List[ReduceNode] = reduce_nodes + self.reduce_nodes: list[ReduceNode] = reduce_nodes self.offset_calc: OffsetCalculator = offset_calc @@ -273,7 +273,7 @@ class DropoutNode(ComputeNode): if there are more than one dropout operators in the subgraph. """ - def __init__(self, inputs: List[TensorArg], outputs: List[TensorArg], offset_calc: OffsetCalculator): + def __init__(self, inputs: list[TensorArg], outputs: list[TensorArg], offset_calc: OffsetCalculator): super().__init__("Dropout", inputs, outputs) self.offset_calc: OffsetCalculator = offset_calc self.offset_calc.register_tensor_arg(inputs[0]) @@ -301,14 +301,14 @@ class KernelNode(IRNode): """ - def __init__(self, inputs: List[TensorArg], outputs: List[TensorArg], target_shape: List, reduce_axes: List[int]): + def __init__(self, inputs: list[TensorArg], outputs: list[TensorArg], target_shape: list, reduce_axes: list[int]): super().__init__(inputs, outputs) self.name: str = gen_unique_name("triton") - self.internal_args: Set[str] = set() - self.constants: Dict[str, TensorArg] = dict() - self.target_shape: List[sympy.Expr] = target_shape - self.sub_nodes: List[IRNode] = [] - self.var_map: Dict[str, str] = dict() + self.internal_args: set[str] = set() + self.constants: dict[str, TensorArg] = dict() + self.target_shape: list[sympy.Expr] = target_shape + self.sub_nodes: list[IRNode] = [] + self.var_map: dict[str, str] = dict() self.has_dropout: bool = False self.offset_calc: OffsetCalculator = OffsetCalculator(target_shape, reduce_axes) @@ -335,18 +335,18 @@ def gen_variable_names(self): class ElementwiseKernelNode(KernelNode): - def __init__(self, inputs: List[TensorArg], outputs: List[TensorArg], target_shape: List[sympy.Expr]): + def __init__(self, inputs: list[TensorArg], outputs: list[TensorArg], target_shape: list[sympy.Expr]): super().__init__(inputs, outputs, target_shape, []) class ReduceKernelNode(KernelNode): def __init__( self, - inputs: List[TensorArg], - outputs: List[TensorArg], - target_shape: List[sympy.Expr], - reduce_axes: List[int], - reduced_args: Set[str], + inputs: list[TensorArg], + outputs: list[TensorArg], + target_shape: list[sympy.Expr], + reduce_axes: list[int], + reduced_args: set[str], ): super().__init__(inputs, outputs, target_shape, reduce_axes) self.offset_calc.reduced_args.update(reduced_args) @@ -361,18 +361,18 @@ class ModuleNode(IRNode): def __init__( self, func_name: str, - inputs: List[TensorArg], - outputs: List[TensorArg], - constants: List[TensorArg], - cross_kernel_args: List[Tuple[TensorArg, int]], - kernels: List[KernelNode], + inputs: list[TensorArg], + outputs: list[TensorArg], + constants: list[TensorArg], + cross_kernel_args: list[tuple[TensorArg, int]], + kernels: list[KernelNode], ): super().__init__(inputs, outputs) self.func_name: str = func_name # Currently need inputs and outputs only. May need intermediate vars and constants later. - self.constants: List[TensorArg] = constants - self.kernels: List[KernelNode] = kernels - self.var_map: Dict[str, str] = dict() + self.constants: list[TensorArg] = constants + self.kernels: list[KernelNode] = kernels + self.var_map: dict[str, str] = dict() existing_names = set() for input in self.inputs: name = gen_variable_name(input.name, "in", existing_names) @@ -380,7 +380,7 @@ def __init__( for output in self.outputs: name = gen_variable_name(output.name, "out", existing_names) self.var_map[output.name] = name - self.cross_kernel_args_to_delete: Dict[int, Set[str]] = defaultdict(set) + self.cross_kernel_args_to_delete: dict[int, set[str]] = defaultdict(set) for pair in cross_kernel_args: name = gen_variable_name(pair[0].name, "buf", existing_names) self.cross_kernel_args_to_delete[pair[1]].add(name) diff --git a/orttraining/orttraining/python/training/ort_triton/_lowering.py b/orttraining/orttraining/python/training/ort_triton/_lowering.py index 7253c7935a650..642f2a02ede6f 100644 --- a/orttraining/orttraining/python/training/ort_triton/_lowering.py +++ b/orttraining/orttraining/python/training/ort_triton/_lowering.py @@ -6,7 +6,7 @@ import itertools import warnings from collections import defaultdict -from typing import Any, Dict, List, Set, Tuple +from typing import Any import sympy from onnx import NodeProto, helper @@ -37,31 +37,31 @@ class NodeGroup: """ - def __init__(self, node: NodeProto, reduce_axes: List[int], keep_dims: int, node_arg_infos: Dict[str, TensorInfo]): + def __init__(self, node: NodeProto, reduce_axes: list[int], keep_dims: int, node_arg_infos: dict[str, TensorInfo]): self._node_arg_infos = node_arg_infos - self.nodes_groups: List[Any] = [node] - self.target_shape: List[sympy.Expr] = self._get_target_shape(node) + self.nodes_groups: list[Any] = [node] + self.target_shape: list[sympy.Expr] = self._get_target_shape(node) rank = len(self.target_shape) - self.reduce_axes: List[int] = sort_reduce_axes(reduce_axes, rank) + self.reduce_axes: list[int] = sort_reduce_axes(reduce_axes, rank) x_dims = [self.target_shape[dim] for dim in range(rank) if dim not in self.reduce_axes] # x_numel is meant to hint how many rows of tensor will be processed by each kernel. # x is same as CUDA block in X direction. x_numel: sympy.Expr = sympy.prod(x_dims) if len(x_dims) > 0 else sympy.Integer(1) - r_dims: List[sympy.Expr] = [self.target_shape[dim] for dim in self.reduce_axes] + r_dims: list[sympy.Expr] = [self.target_shape[dim] for dim in self.reduce_axes] # r_numel is meant to hint how many elements in a row of tensor will be processed by each kernel. # r is a abbreviation of reduction, so, it's only used for reduction nodes. r_numel: sympy.Expr = sympy.prod(r_dims) if len(r_dims) > 0 else sympy.Integer(1) self.autotune_configs: AutotuneConfigs = AutotuneConfigs( x_numel, r_numel, len(self.reduce_axes) == 0 or self.reduce_axes[-1] == rank - 1 ) - self.reduced_args: Set[str] = set() + self.reduced_args: set[str] = set() if keep_dims != 1: self.reduced_args.add(node.output[0]) # Check if shape can be broadcasted to target_shape. # For example, [1, 3, 1, 1] can be broadcasted to [1, 3, 5, 7]. # and we support `keepdims = false``, so [1, 3, 5, 7] is compatible with [1, 3, 5]. - def _compatible_shape(self, shape: List[sympy.Expr], split_if_different: bool) -> bool: + def _compatible_shape(self, shape: list[sympy.Expr], split_if_different: bool) -> bool: if split_if_different: return shape == self.target_shape if len(shape) > len(self.target_shape): @@ -88,7 +88,7 @@ def _get_target_shape(self, node): # 2. The target shape of a group is determined by the first node in the group. # we call it dominators, and it determinate the partition strategy of X_numel/R_numel. # A group can't have multiple dominators. - def compatible(self, node: NodeProto, reduce_axes: List[int], keep_dims: int, split_if_different: bool) -> bool: + def compatible(self, node: NodeProto, reduce_axes: list[int], keep_dims: int, split_if_different: bool) -> bool: target_shape = self._get_target_shape(node) if is_reduction_node(node): # If the following nodes are all elementwise nodes on reduce output shape. @@ -105,7 +105,7 @@ def compatible(self, node: NodeProto, reduce_axes: List[int], keep_dims: int, sp # 1. Create a new group with the reduction node. # 2. Add this node to the current group. - def add_node(self, node: NodeProto, reduce_axes: List[int], keep_dims: int): + def add_node(self, node: NodeProto, reduce_axes: list[int], keep_dims: int): if is_reduction_node(node): group = NodeGroup(node, reduce_axes, keep_dims, self._node_arg_infos) self.nodes_groups.append(group) @@ -142,7 +142,7 @@ def dependent_nodes(self, keep_reduce_node: bool): return node_map, reduce_nodes # finalize the group, and return the flatten nodes - def flatten(self, sorted_nodes: List[NodeProto]) -> Tuple[List[NodeProto], List[List[int]]]: + def flatten(self, sorted_nodes: list[NodeProto]) -> tuple[list[NodeProto], list[list[int]]]: if self.autotune_configs.requires_for_loop: layers = [] group_layer = [self] @@ -193,12 +193,12 @@ class KernelIO: """ def __init__(self): - self.module_inputs: List[str] = [] - self.cross_kernel_inputs: List[str] = [] - self.constants: List[str] = [] - self.module_outputs: List[str] = [] - self.cross_kernel_outputs: List[str] = [] - self.internal_args: List[str] = [] + self.module_inputs: list[str] = [] + self.cross_kernel_inputs: list[str] = [] + self.constants: list[str] = [] + self.module_outputs: list[str] = [] + self.cross_kernel_outputs: list[str] = [] + self.internal_args: list[str] = [] class GraphLowering: @@ -217,24 +217,24 @@ class GraphLowering: def __init__(self, sorted_graph: SortedGraph): self._sorted_graph: SortedGraph = sorted_graph - self._node_arg_infos: Dict[str, TensorInfo] = sorted_graph.node_arg_infos - self._module_inputs: List[TensorArg] = [] - self._module_outputs: List[TensorArg] = [] - self._module_constants: List[TensorArg] = [] - self._module_input_names: Set[str] = set() - self._module_output_names: Set[str] = set() - self._module_constant_names: Set[str] = set() - self._tensor_args: Dict[str, TensorArg] = {} + self._node_arg_infos: dict[str, TensorInfo] = sorted_graph.node_arg_infos + self._module_inputs: list[TensorArg] = [] + self._module_outputs: list[TensorArg] = [] + self._module_constants: list[TensorArg] = [] + self._module_input_names: set[str] = set() + self._module_output_names: set[str] = set() + self._module_constant_names: set[str] = set() + self._tensor_args: dict[str, TensorArg] = {} # Extract module inputs, outputs and constants. self._extract_module_io() # Group nodes into NodeGroups, each NodeGroup represents a kernel. - self._groups: List[NodeGroup] = [] + self._groups: list[NodeGroup] = [] self._group_nodes() # Convert NodeGroups to KernelNodes. - self._kernel_nodes: List[KernelNode] = [] - self._kernel_io_list: List[KernelIO] = [] + self._kernel_nodes: list[KernelNode] = [] + self._kernel_io_list: list[KernelIO] = [] self._lower() # A module is map to a real onnx graph. @@ -256,12 +256,12 @@ def _extract_module_io(self): for arg in itertools.chain(self._module_inputs, self._module_outputs, self._module_constants) ) - def _get_reduce_info(self, node) -> Tuple[int, List[int]]: + def _get_reduce_info(self, node) -> tuple[int, list[int]]: assert is_reduction_node(node) input_rank = len(self._node_arg_infos[node.input[0]].shape) return get_reduce_info(node, self._sorted_graph.original_graph, input_rank) - def _process_node(self, node: NodeProto, precessors: Dict[str, List[NodeProto]], group: NodeGroup): + def _process_node(self, node: NodeProto, precessors: dict[str, list[NodeProto]], group: NodeGroup): dependent_nodes = set() dependent_nodes.add(node.name) for precessor in precessors[node.name]: @@ -328,7 +328,7 @@ def _group_nodes(self): self._groups.append(group_i) flag.add(i) - def _get_node_io(self, node: NodeProto) -> Tuple[List[TensorArg], List[TensorArg]]: + def _get_node_io(self, node: NodeProto) -> tuple[list[TensorArg], list[TensorArg]]: input_args = [] for input in node.input: if input in self._tensor_args: @@ -345,7 +345,7 @@ def _get_node_io(self, node: NodeProto) -> Tuple[List[TensorArg], List[TensorArg self._tensor_args[output] = output_args[-1] return input_args, output_args - def _extract_kernel_io(self, nodes: List[NodeProto]) -> KernelIO: + def _extract_kernel_io(self, nodes: list[NodeProto]) -> KernelIO: kernel_io = KernelIO() input_set = set() output_set = set() diff --git a/orttraining/orttraining/python/training/ort_triton/_sorted_graph.py b/orttraining/orttraining/python/training/ort_triton/_sorted_graph.py index d67a1c1665200..722f05dfdf493 100644 --- a/orttraining/orttraining/python/training/ort_triton/_sorted_graph.py +++ b/orttraining/orttraining/python/training/ort_triton/_sorted_graph.py @@ -5,7 +5,6 @@ import copy import itertools -from typing import Dict, List, Set import onnx import sympy @@ -30,14 +29,14 @@ class SortedGraph: input_shapes: the shapes of the model inputs. Can be numeric values or symbolic values. """ - def __init__(self, model: ModelProto, input_shapes: List[List[sympy.Expr]]): + def __init__(self, model: ModelProto, input_shapes: list[list[sympy.Expr]]): self._model: ModelProto = model self._graph: GraphProto = model.graph - self._input_shapes: List[List[sympy.Expr]] = input_shapes + self._input_shapes: list[list[sympy.Expr]] = input_shapes # For elementwise graph outputs, when we group nodes to different kernels, if the target shape is different # from other nodes' target shape, even it can be broadcasted, we still need to create a new kernel for it. - self._elementwise_graph_outputs: Set[str] = set() + self._elementwise_graph_outputs: set[str] = set() graph_output_names = [output.name for output in self._graph.output] for node in self._graph.node: if is_elementwise_node(node): @@ -46,12 +45,12 @@ def __init__(self, model: ModelProto, input_shapes: List[List[sympy.Expr]]): ) # Topological sort the nodes in the graph. - self._sorted_nodes: List[NodeProto] = topological_sort( + self._sorted_nodes: list[NodeProto] = topological_sort( [input.name for input in self._graph.input] + [initializer.name for initializer in self._graph.initializer], self._graph.node, ) - self._node_arg_infos: Dict[str, TensorInfo] = {} + self._node_arg_infos: dict[str, TensorInfo] = {} for idx, input in enumerate(self._graph.input): self._node_arg_infos[input.name] = TensorInfo(input.type.tensor_type.elem_type, self._input_shapes[idx]) for initializer in self._graph.initializer: @@ -70,7 +69,7 @@ def __init__(self, model: ModelProto, input_shapes: List[List[sympy.Expr]]): initializers = {} for initializer in self._graph.initializer: initializers[initializer.name] = initializer - self._sorted_initializers: List[TensorProto] = [] + self._sorted_initializers: list[TensorProto] = [] for node in self._sorted_nodes: for input in node.input: if input in initializers: @@ -78,8 +77,8 @@ def __init__(self, model: ModelProto, input_shapes: List[List[sympy.Expr]]): initializers.pop(input) # Split nodes to constant nodes and non-constant nodes. - self._const_nodes: List[NodeProto] = [node for node in self._sorted_nodes if node.op_type == "Constant"] - self._sorted_nodes: List[NodeProto] = [node for node in self._sorted_nodes if node.op_type != "Constant"] + self._const_nodes: list[NodeProto] = [node for node in self._sorted_nodes if node.op_type == "Constant"] + self._sorted_nodes: list[NodeProto] = [node for node in self._sorted_nodes if node.op_type != "Constant"] def __str__(self): """ @@ -140,11 +139,11 @@ def __eq__(self, other): return str(self) == str(other) @property - def const_nodes(self) -> List[NodeProto]: + def const_nodes(self) -> list[NodeProto]: return self._const_nodes @property - def sorted_nodes(self) -> List[NodeProto]: + def sorted_nodes(self) -> list[NodeProto]: return self._sorted_nodes @property @@ -152,11 +151,11 @@ def original_graph(self) -> GraphProto: return self._graph @property - def node_arg_infos(self) -> Dict[str, TensorInfo]: + def node_arg_infos(self) -> dict[str, TensorInfo]: return self._node_arg_infos @property - def elementwise_graph_outputs(self) -> Set[str]: + def elementwise_graph_outputs(self) -> set[str]: return self._elementwise_graph_outputs def _decompose(self): diff --git a/orttraining/orttraining/python/training/ort_triton/_sympy_utils.py b/orttraining/orttraining/python/training/ort_triton/_sympy_utils.py index a4a384c021fe8..1df587fda054e 100644 --- a/orttraining/orttraining/python/training/ort_triton/_sympy_utils.py +++ b/orttraining/orttraining/python/training/ort_triton/_sympy_utils.py @@ -4,7 +4,7 @@ # -------------------------------------------------------------------------- import re -from typing import Any, List +from typing import Any import sympy @@ -15,12 +15,12 @@ def extract_shape_from_symbol(symbol: str) -> int: return int(match.group(3)) -def sympy_dot(seq1: List[sympy.Expr], seq2: List[sympy.Expr]) -> sympy.Expr: +def sympy_dot(seq1: list[sympy.Expr], seq2: list[sympy.Expr]) -> sympy.Expr: assert len(seq1) == len(seq2) - return sympy.expand(sum(a * b for a, b in zip(seq1, seq2))) + return sympy.expand(sum(a * b for a, b in zip(seq1, seq2, strict=False))) -def parse_shape(shape: List[Any]) -> List[sympy.Expr]: +def parse_shape(shape: list[Any]) -> list[sympy.Expr]: symbol_shapes = [] for dim in shape: symbol_dim = dim diff --git a/orttraining/orttraining/python/training/ort_triton/_utils.py b/orttraining/orttraining/python/training/ort_triton/_utils.py index e39a668bd0066..3cf5cfa184861 100644 --- a/orttraining/orttraining/python/training/ort_triton/_utils.py +++ b/orttraining/orttraining/python/training/ort_triton/_utils.py @@ -6,7 +6,7 @@ import re import uuid from collections import defaultdict -from typing import Any, List, Tuple +from typing import Any import numpy as np import torch @@ -27,7 +27,7 @@ def _topological_sort_internal(node, visited, output_consumers, sorted_nodes): # Topological sort of nodes given the input names. The list of nodes contain both constant and non-constant nodes. -def topological_sort(inputs: List[str], nodes: List[NodeProto]) -> List[NodeProto]: +def topological_sort(inputs: list[str], nodes: list[NodeProto]) -> list[NodeProto]: const_nodes = [] non_const_nodes = [] for node in nodes: @@ -119,7 +119,7 @@ def may_add_brackets(name: str) -> str: return name -def sort_reduce_axes(axes: List[int], rank: int, check_contiguous: bool = True) -> List[int]: +def sort_reduce_axes(axes: list[int], rank: int, check_contiguous: bool = True) -> list[int]: axes = [axis + rank if axis < 0 else axis for axis in axes] axes.sort() if check_contiguous: @@ -129,7 +129,7 @@ def sort_reduce_axes(axes: List[int], rank: int, check_contiguous: bool = True) # Get the keep_dims attribute and reduce axes from a reduce node. -def get_reduce_info(node: NodeProto, graph: GraphProto, input_rank: int) -> Tuple[int, List[int]]: +def get_reduce_info(node: NodeProto, graph: GraphProto, input_rank: int) -> tuple[int, list[int]]: keep_dims = get_attribute(node, "keepdims", 1) noop_with_empty_axes = get_attribute(node, "noop_with_empty_axes", 0) axes = get_attribute(node, "axes", None) diff --git a/orttraining/orttraining/python/training/ort_triton/kernel/_flash_attn.py b/orttraining/orttraining/python/training/ort_triton/kernel/_flash_attn.py index 3850d988ef473..67394fe297d51 100644 --- a/orttraining/orttraining/python/training/ort_triton/kernel/_flash_attn.py +++ b/orttraining/orttraining/python/training/ort_triton/kernel/_flash_attn.py @@ -40,7 +40,6 @@ """ import math -from typing import List, Tuple import torch import triton @@ -1009,7 +1008,7 @@ def _make_flash_attention_nodes( # Without causal mask, without Dropout. For example, BERT model in HuggingFace. -_PATTERN_0: List[Tuple[str, bool, List[Tuple[int, int, int]]]] = [ +_PATTERN_0: list[tuple[str, bool, list[tuple[int, int, int]]]] = [ ("MatMul", False, []), # 0 ("Transpose", True, [(0, 0, 0)]), # 1 ("Transpose", True, [(0, 0, 1)]), # 2 @@ -1034,7 +1033,7 @@ def _make_flash_attention_nodes( ] -def _optimize_for_pattern_0(matcher: GraphMatcher, idx: int, nodes: List[NodeProto]): +def _optimize_for_pattern_0(matcher: GraphMatcher, idx: int, nodes: list[NodeProto]): # Check forward only as the backward is expected to be consistent if it's built correctly. scale_value = matcher.get_constant_value(nodes[3].input[1]) if not ( @@ -1063,7 +1062,7 @@ def _optimize_for_pattern_0(matcher: GraphMatcher, idx: int, nodes: List[NodePro # llama2+peft, k doesn't require grad. -_PATTERN_1: List[Tuple[str, bool, List[Tuple[int, int, int]]]] = [ +_PATTERN_1: list[tuple[str, bool, list[tuple[int, int, int]]]] = [ ("MatMul", False, []), # 0 ("Transpose", True, [(0, 0, 1)]), # 1 ("Div", False, [(0, 0, 0)]), # 2 @@ -1087,7 +1086,7 @@ def _optimize_for_pattern_0(matcher: GraphMatcher, idx: int, nodes: List[NodePro ] -def _optimize_for_pattern_1(matcher: GraphProto, idx: int, nodes: List[NodeProto]): +def _optimize_for_pattern_1(matcher: GraphProto, idx: int, nodes: list[NodeProto]): # Check forward only as the backward is expected to be consistent if it's built correctly. scale_value = matcher.get_constant_value(nodes[2].input[1]) if not ( @@ -1138,7 +1137,7 @@ def _optimize_for_pattern_1(matcher: GraphProto, idx: int, nodes: List[NodeProto # llama2+peft, k requires grad. -_PATTERN_2: List[Tuple[str, bool, List[Tuple[int, int, int]]]] = [ +_PATTERN_2: list[tuple[str, bool, list[tuple[int, int, int]]]] = [ ("MatMul", False, []), # 0 ("Transpose", True, [(0, 0, 1)]), # 1 ("Div", False, [(0, 0, 0)]), # 2 @@ -1164,7 +1163,7 @@ def _optimize_for_pattern_1(matcher: GraphProto, idx: int, nodes: List[NodeProto ] -def _aptimize_for_pattern_2(matcher: GraphProto, idx: int, nodes: List[NodeProto]): +def _aptimize_for_pattern_2(matcher: GraphProto, idx: int, nodes: list[NodeProto]): # Check forward only as the backward is expected to be consistent if it's built correctly. scale_value = matcher.get_constant_value(nodes[2].input[1]) if not ( diff --git a/orttraining/orttraining/python/training/ort_triton/kernel/_mm.py b/orttraining/orttraining/python/training/ort_triton/kernel/_mm.py index 1a944082fa4ba..dffdac0f34553 100644 --- a/orttraining/orttraining/python/training/ort_triton/kernel/_mm.py +++ b/orttraining/orttraining/python/training/ort_triton/kernel/_mm.py @@ -6,7 +6,6 @@ import math import os from types import ModuleType -from typing import Tuple import torch @@ -310,7 +309,7 @@ def _gen_mm_key(dtype: torch.dtype, m: int, n: int, k: int, trans_a: bool, trans def _gen_mm_module( dtype: torch.dtype, m: int, n: int, k: int, trans_a: bool, trans_b: bool, alpha: float -) -> Tuple[str, ModuleType]: +) -> tuple[str, ModuleType]: func_name = gen_unique_name("mm") kwargs = _mm_configs(dtype, m, n, k, trans_a, trans_b, alpha, func_name) src_code = _MM_TEMPLATE.format(**kwargs) @@ -347,7 +346,7 @@ def _gen_gemm_module( trans_b: bool, alpha: float, beta: float, -) -> Tuple[str, ModuleType]: +) -> tuple[str, ModuleType]: func_name = gen_unique_name("gemm") kwargs = _mm_configs(dtype, m, n, k, trans_a, trans_b, alpha, func_name) kwargs["stride_cm"] = stride_cm @@ -369,7 +368,7 @@ def _gen_bmm_key( def _gen_bmm_module( dtype: torch.dtype, m: int, n: int, k: int, batch_a: int, batch_b: int, trans_a: bool, trans_b: bool, alpha: float -) -> Tuple[str, ModuleType]: +) -> tuple[str, ModuleType]: func_name = gen_unique_name("bmm") kwargs = _mm_configs(dtype, m, n, k, trans_a, trans_b, alpha, func_name) batch = max(batch_a, batch_b) diff --git a/orttraining/orttraining/python/training/ort_triton/triton_op_executor.py b/orttraining/orttraining/python/training/ort_triton/triton_op_executor.py index 14bc2779aa05b..47d220826f73e 100644 --- a/orttraining/orttraining/python/training/ort_triton/triton_op_executor.py +++ b/orttraining/orttraining/python/training/ort_triton/triton_op_executor.py @@ -9,7 +9,6 @@ import re import sys from types import ModuleType -from typing import List, Tuple, Union import onnx from onnx import ModelProto @@ -29,7 +28,7 @@ @functools.lru_cache(None) -def _gen_module_internal(sorted_graph: SortedGraph) -> Tuple[str, str, ModuleType]: +def _gen_module_internal(sorted_graph: SortedGraph) -> tuple[str, str, ModuleType]: func_name = gen_unique_name("func") src_code = codegen(func_name, sorted_graph) return func_name, src_code, PyCodeCache().load(src_code) @@ -58,7 +57,7 @@ def set_symbolic_shape_hint(cls, symbolic_shape_hint_config): cls.symbolic_shape_hint[k] = v @classmethod - def get_shape(cls, onnx_key: int, model: ModelProto, shapes: List[List[int]]) -> List[List[Union[int, str]]]: + def get_shape(cls, onnx_key: int, model: ModelProto, shapes: list[list[int]]) -> list[list[int | str]]: if onnx_key not in cls.cache: if cls.symbolic_shape_hint is not None: for i, input in enumerate(model.graph.input): @@ -90,12 +89,12 @@ def get_shape(cls, onnx_key: int, model: ModelProto, shapes: List[List[int]]) -> return cls.cache[onnx_key] -def _gen_key(onnx_key: int, model: ModelProto, shapes: List[List[Union[int, str]]]) -> int: +def _gen_key(onnx_key: int, model: ModelProto, shapes: list[list[int | str]]) -> int: # pylint: disable=unused-argument return hash(f"{onnx_key}|{str(shapes).replace(' ', '')}") -def _gen_module(onnx_key: int, model: ModelProto, shapes: List[List[Union[int, str]]]) -> Tuple[str, ModuleType]: +def _gen_module(onnx_key: int, model: ModelProto, shapes: list[list[int | str]]) -> tuple[str, ModuleType]: sorted_graph = SortedGraph(model, [parse_shape(shape) for shape in shapes]) if _DEBUG_MODE: os.makedirs(os.path.dirname("triton_debug/"), exist_ok=True) diff --git a/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_exporter.py b/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_exporter.py index 3e679c994f4bb..9ac65bde82bf8 100644 --- a/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_exporter.py +++ b/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_exporter.py @@ -223,7 +223,7 @@ def _default_export( assert len(args) == len(cconv), "Number of arguments does not match calling convention" # Encode inputs to torch.autograd.Function. - for i, arg, call_type in zip(range(len(args)), args, cconv): + for i, arg, call_type in zip(range(len(args)), args, cconv, strict=False): if call_type == "d": # Got a tensor variable. tensor_args.append(arg) diff --git a/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py b/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py index 004e3540c62d6..3762c8995cdb1 100644 --- a/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py +++ b/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py @@ -4,7 +4,7 @@ # -------------------------------------------------------------------------- import os -from typing import Callable +from collections.abc import Callable import torch import torch.onnx.symbolic_helper as sym_help diff --git a/orttraining/orttraining/python/training/ortmodule/_execution_agent.py b/orttraining/orttraining/python/training/ortmodule/_execution_agent.py index 047cd4c59d636..8d64caeec6051 100644 --- a/orttraining/orttraining/python/training/ortmodule/_execution_agent.py +++ b/orttraining/orttraining/python/training/ortmodule/_execution_agent.py @@ -3,7 +3,6 @@ # Licensed under the MIT License. # -------------------------------------------------------------------------- -from typing import Tuple import onnxruntime from onnxruntime.capi import _pybind_state as C @@ -166,7 +165,7 @@ def run_backward(self, feeds, fetches, state): def get_serialized_ortmodule_memory_stat( self, memory_optimization_config_file_path: str, recompute_probe_level: str, return_opportunity_table: bool - ) -> Tuple[str, dict]: + ) -> tuple[str, dict]: """ Get serialized memory stats for OrtModule. """ diff --git a/orttraining/orttraining/python/training/ortmodule/_fallback.py b/orttraining/orttraining/python/training/ortmodule/_fallback.py index 6a3793cf0f1fd..24eae3c369efe 100644 --- a/orttraining/orttraining/python/training/ortmodule/_fallback.py +++ b/orttraining/orttraining/python/training/ortmodule/_fallback.py @@ -6,7 +6,6 @@ import os from enum import IntFlag from logging import Logger -from typing import Optional import torch @@ -106,7 +105,7 @@ def __init__(self, pytorch_module: torch.nn.Module, policy: _FallbackPolicy, ret self._logger = logger def handle_exception( - self, exception: Exception, log_level: _logger.LogLevel, override_policy: Optional[_FallbackPolicy] = None + self, exception: Exception, log_level: _logger.LogLevel, override_policy: _FallbackPolicy | None = None ) -> None: """Process incoming `exception` based on the selected `policy` diff --git a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py index c1ff62a5faea7..25dfd9c3d43dd 100755 --- a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py @@ -7,7 +7,6 @@ import logging import os from abc import ABC, abstractmethod # noqa: F401 -from typing import Dict, List, Optional, Tuple import onnx import torch @@ -30,7 +29,7 @@ class _RunStateInfo: - def __init__(self, state, output_info: List[Tuple[torch.Size, torch.device, torch.dtype]]): + def __init__(self, state, output_info: list[tuple[torch.Size, torch.device, torch.dtype]]): """ :param state: State of partial run that contains intermediate tensors needed to resume the run later. :param output_info: Output info. @@ -74,7 +73,7 @@ def __init__( self._flattened_module = module self._onnx_models = _onnx_models.ONNXModels() - self._graph_transition_manager: Optional[GraphTransitionManager] = None + self._graph_transition_manager: GraphTransitionManager | None = None # Model after inference optimization and then gradient building. self._graph_builder = None @@ -341,7 +340,7 @@ def _device(self): return self._graph_transition_manager._device @_logger.TrackTime(_logger.ORTModuleInitPhase.DETECTION) - def _detect_from_inputs(self, inputs: Tuple, kwargs: Dict): + def _detect_from_inputs(self, inputs: tuple, kwargs: dict): """ Based on runtime inspection, enable conditional optimizations if applicable. @@ -381,7 +380,7 @@ def _detect_from_inputs(self, inputs: Tuple, kwargs: Dict): [f"{k}:{v:.0f}%" for k, v in self._runtime_inspector._embedding_module_to_padding_density_map.items()] ) - def _append_pull_weight_trigger_as_input(self, kwargs: Dict, device: torch.device): + def _append_pull_weight_trigger_as_input(self, kwargs: dict, device: torch.device): if self._runtime_options.enable_zero_stage3_support: from ._zero_stage3_compatibility import ( STAGE3_PULL_WEIGHT_TRIGGER_NAME, diff --git a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager_factory.py b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager_factory.py index 104cc0a894eed..237aafd6d2c3c 100644 --- a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager_factory.py +++ b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager_factory.py @@ -4,7 +4,6 @@ # -------------------------------------------------------------------------- from logging import Logger -from typing import Union from ._fallback import _FallbackManager from ._inference_manager import InferenceManager @@ -24,7 +23,7 @@ def __init__( self._training_manager = TrainingManager(module, debug_options, fallback_manager, logger) self._inference_manager = InferenceManager(module, debug_options, fallback_manager, logger) - def __call__(self, is_training) -> Union[InferenceManager, TrainingManager]: + def __call__(self, is_training) -> InferenceManager | TrainingManager: if is_training: return self._training_manager else: diff --git a/orttraining/orttraining/python/training/ortmodule/_graph_transition_manager.py b/orttraining/orttraining/python/training/ortmodule/_graph_transition_manager.py index bbf271e4e9b74..ba215bd86c5a3 100755 --- a/orttraining/orttraining/python/training/ortmodule/_graph_transition_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_graph_transition_manager.py @@ -11,9 +11,9 @@ import logging import os from collections import OrderedDict +from collections.abc import Mapping, Sequence from functools import partial from hashlib import md5 as hash_fn -from typing import Mapping, Sequence import onnx import torch diff --git a/orttraining/orttraining/python/training/ortmodule/_inference_manager.py b/orttraining/orttraining/python/training/ortmodule/_inference_manager.py index 61db462ad3bb8..362f1a88ce924 100644 --- a/orttraining/orttraining/python/training/ortmodule/_inference_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_inference_manager.py @@ -4,7 +4,6 @@ # -------------------------------------------------------------------------- from logging import Logger -from typing import Tuple import onnx import torch @@ -35,7 +34,7 @@ def execution_session_run_forward( onnx_model: onnx.ModelProto, device: torch.device, *inputs, - ) -> Tuple[Tuple[torch.Tensor, ...], _RunStateInfo]: + ) -> tuple[tuple[torch.Tensor, ...], _RunStateInfo]: """Runs the forward pass on `execution_session` with given `onnx_model`, `device` and `inputs` Args: diff --git a/orttraining/orttraining/python/training/ortmodule/_io.py b/orttraining/orttraining/python/training/ortmodule/_io.py index 8ad3d0df3e4fa..f88390130b81f 100644 --- a/orttraining/orttraining/python/training/ortmodule/_io.py +++ b/orttraining/orttraining/python/training/ortmodule/_io.py @@ -7,9 +7,9 @@ import gc import inspect from collections import OrderedDict, abc +from collections.abc import Callable, Mapping, Sequence from functools import partial from logging import Logger -from typing import Callable, Dict, List, Mapping, Optional, Sequence, Tuple import torch @@ -78,7 +78,7 @@ def symbolic(g, self): def deepcopy_model_input( *args, **kwargs -) -> Tuple[Sequence[ORTModelInputOutputType], Mapping[str, ORTModelInputOutputType]]: +) -> tuple[Sequence[ORTModelInputOutputType], Mapping[str, ORTModelInputOutputType]]: def extract_tensor(value): if isinstance(value, torch.Tensor): if value.requires_grad: @@ -101,7 +101,7 @@ def extract_tensor(value): def _extract_schema( data: ORTModelInputOutputType, device -) -> Tuple[Sequence[ORTModelInputOutputType], ORTModelInputOutputSchemaType]: +) -> tuple[Sequence[ORTModelInputOutputType], ORTModelInputOutputSchemaType]: try: flatten_data, schema = extract_data_and_schema(data, constant_as_tensor=True, device=device) return flatten_data, schema @@ -119,15 +119,15 @@ def __init__(self, original_module: torch.nn.Module): # original module's forward function. # So we need set those information that are needed to unflatten the args and kwargs, before calling the # torch.export. - self._device: Optional[torch.device] = None - self._args_schema: Optional[ORTModelInputOutputSchemaType] = None - self._kwargs_schema: Optional[ORTModelInputOutputSchemaType] = None - self._num_positionals: Optional[int] = None + self._device: torch.device | None = None + self._args_schema: ORTModelInputOutputSchemaType | None = None + self._kwargs_schema: ORTModelInputOutputSchemaType | None = None + self._num_positionals: int | None = None # Similarly, to make torch.export happy, we need to flatten the original module's outputs into a 1-D list of tensors. # Need to keep the output schema to unflatten the outputs back to the original structure. # Then those code depends on the original structure of the outputs can work properly. - self._output_schema: Optional[ORTModelInputOutputSchemaType] = None + self._output_schema: ORTModelInputOutputSchemaType | None = None def forward(self, *args): new_args = unflatten_data_using_schema(args[: self._num_positionals], self._args_schema) @@ -150,17 +150,17 @@ def forward(self, *args): class ModelInfoForExport: def __init__( self, - onnx_graph_input_names: List[str], - onnx_graph_input_names_require_grad: List[str], - onnx_graph_input_dynamic_axes_map: Dict[str, Dict[int, str]], - onnx_graph_input_shapes: List[List[int]], - onnx_graph_input_data_accessor_user_defined: Optional[Dict[str, callable]] = None, - onnx_graph_input_const_as_tensor: Optional[Dict[str, torch.device]] = None, - onnx_graph_input_arg_schema: Optional[Dict[str, ORTModelInputOutputSchemaType]] = None, - onnx_graph_input_kwarg_schema: Optional[Dict[str, ORTModelInputOutputSchemaType]] = None, + onnx_graph_input_names: list[str], + onnx_graph_input_names_require_grad: list[str], + onnx_graph_input_dynamic_axes_map: dict[str, dict[int, str]], + onnx_graph_input_shapes: list[list[int]], + onnx_graph_input_data_accessor_user_defined: dict[str, callable] | None = None, + onnx_graph_input_const_as_tensor: dict[str, torch.device] | None = None, + onnx_graph_input_arg_schema: dict[str, ORTModelInputOutputSchemaType] | None = None, + onnx_graph_input_kwarg_schema: dict[str, ORTModelInputOutputSchemaType] | None = None, num_positional_args: int = 0, - export_mode: Optional[int] = None, - export_extra_kwargs: Optional[Dict[str, any]] = None, + export_mode: int | None = None, + export_extra_kwargs: dict[str, any] | None = None, ): # Value can be either torch.onnx.TrainingMode.TRAINING or torch.onnx.TrainingMode.EVAL self.export_mode = export_mode @@ -172,41 +172,41 @@ def __init__( # Input names parsed and then flatten from the model's forward function signature. # This should contains ONLY the user defined input names # Be noted: some of the input might not be used by the model for its compute. - self.onnx_graph_input_names: List[str] = onnx_graph_input_names + self.onnx_graph_input_names: list[str] = onnx_graph_input_names # A subset of onnx_graph_input_names. # Input names that require gradient parsed and then flatten from the model's forward function signature # This should contains ONLY the user defined input names # Be noted: some of the input might not be used by the model for its compute. - self.onnx_graph_input_names_require_grad: List[str] = onnx_graph_input_names_require_grad + self.onnx_graph_input_names_require_grad: list[str] = onnx_graph_input_names_require_grad # Create symbolic names for each dimension of the graph input (e.g. onnx_graph_input_names). # The key is the input name, the value is a dict of {dim_index: symbolic_dim_name} # e.g. {"input1": {0: "input1_dim0", 1: "input1_dim1"}, "input2": {0: "input2_dim0"}} - self.onnx_graph_input_dynamic_axes_map: Dict[str, Dict[int, str]] = onnx_graph_input_dynamic_axes_map + self.onnx_graph_input_dynamic_axes_map: dict[str, dict[int, str]] = onnx_graph_input_dynamic_axes_map - self.onnx_graph_input_shapes: List[List[int]] = onnx_graph_input_shapes + self.onnx_graph_input_shapes: list[list[int]] = onnx_graph_input_shapes # The input args schema for the original model's forward function. # Only contains the schema for those inputs used by the model for its compute (e.g. as the inputs # of the export model). - self.onnx_graph_input_arg_schema: Dict[str, ORTModelInputOutputSchemaType] = onnx_graph_input_arg_schema + self.onnx_graph_input_arg_schema: dict[str, ORTModelInputOutputSchemaType] = onnx_graph_input_arg_schema # The input kwargs schema for the original model's forward function. # Only contains the schema for those inputs used by the model for its compute (e.g. as the inputs # of the export model). - self.onnx_graph_input_kwarg_schema: Dict[str, ORTModelInputOutputSchemaType] = onnx_graph_input_kwarg_schema + self.onnx_graph_input_kwarg_schema: dict[str, ORTModelInputOutputSchemaType] = onnx_graph_input_kwarg_schema self.num_positional_args: int = num_positional_args # A function to access the input data from the args and kwargs. # If it is not None, the length is same as onnx_graph_input_names. # For i-th input name, we can use the i-th function to get the input data from args and kwargs. - self.onnx_graph_input_data_accessor_user_defined: Optional[Dict[str, callable]] = ( + self.onnx_graph_input_data_accessor_user_defined: dict[str, callable] | None = ( onnx_graph_input_data_accessor_user_defined ) - self.onnx_graph_input_const_as_tensor: Optional[Dict[str, torch.device]] = onnx_graph_input_const_as_tensor + self.onnx_graph_input_const_as_tensor: dict[str, torch.device] | None = onnx_graph_input_const_as_tensor def __str__(self) -> str: return f"""ModelInfoForExport class: @@ -237,14 +237,14 @@ class SkipRetValue: def parse_inputs_for_onnx_export( - all_input_parameters: List[inspect.Parameter], + all_input_parameters: list[inspect.Parameter], args: Sequence[ORTModelInputOutputType], kwargs: Mapping[str, ORTModelInputOutputType], constant_as_tensor: bool, device: torch.device, export_mode: int, logger: Logger, - export_extra_kwargs: Optional[Dict[str, any]] = None, + export_extra_kwargs: dict[str, any] | None = None, ) -> ModelInfoForExport: """Parses through the model inputs and returns _InputInfo. @@ -275,7 +275,7 @@ def parse_inputs_for_onnx_export( arg_tensor_idx = [-1] kwarg_tensor_idx = [-1] - def _add_dynamic_shape(name, input) -> Dict[str, Dict[int, str]]: + def _add_dynamic_shape(name, input) -> dict[str, dict[int, str]]: dynamic_axes[name] = {} for dim_idx in range(len(input.shape)): dynamic_axes[name].update({dim_idx: f"{name}_dim{dim_idx}"}) @@ -285,7 +285,7 @@ def _warn_of_constant_inputs(data): logger.info(f"Received input of type {type(data)} is treated as a constant by ORT by default.") def _add_input( - name: str, input_value, onnx_graph_input_names: List[str], cur_func: Callable, tensor_idx: List[int] + name: str, input_value, onnx_graph_input_names: list[str], cur_func: Callable, tensor_idx: list[int] ): """Returns number of expanded non none inputs that _add_input processed""" @@ -396,16 +396,16 @@ def _access_func(key, cur_func, args, kwargs): raise ORTModuleIOError(f"ORTModule does not support input type {type(value)} for input {name}") - visited_input_names: List[str] = [] + visited_input_names: list[str] = [] - onnx_graph_input_names: List[str] = [] - dynamic_axes: Dict[str, Dict[int, str]] = {} - input_names_require_grad: List[str] = [] - input_shape: List[List[int]] = [] + onnx_graph_input_names: list[str] = [] + dynamic_axes: dict[str, dict[int, str]] = {} + input_names_require_grad: list[str] = [] + input_shape: list[list[int]] = [] input_arg_schema: ORTModelInputOutputSchemaType = [] input_kwarg_schema: ORTModelInputOutputSchemaType = OrderedDict() - data_accessors: Dict[str, Callable] = OrderedDict() - const_to_tensor_inputs: Dict[str, torch.device] = OrderedDict() + data_accessors: dict[str, Callable] = OrderedDict() + const_to_tensor_inputs: dict[str, torch.device] = OrderedDict() num_positional_args: int = 0 var_positional_idx = 0 @@ -511,7 +511,7 @@ def calculate_total_parameter_size_in_bytes(module: torch.nn.Module) -> int: return total_size -def can_module_be_deep_cloned(module: torch.nn.Module, device: Optional[torch.device]) -> bool: +def can_module_be_deep_cloned(module: torch.nn.Module, device: torch.device | None) -> bool: """Check if the module can be cloned If the 2 times total module parameter size >= device memory, the module cannot be cloned. @@ -568,8 +568,8 @@ def parse_outputs_for_onnx_export_and_extract_schema( sample_outputs = model_copy(*sample_args_copy, **sample_kwargs_copy) # Parse the output and extract the output_names and output_dynamic_axes to be used for onnx export - output_names: List[str] = [] - output_dynamic_axes: Dict[str, Dict[int, str]] = {} + output_names: list[str] = [] + output_dynamic_axes: dict[str, dict[int, str]] = {} for output_idx, output in enumerate(sample_outputs): output_name = f"output-{output_idx}" output_names.append(output_name) diff --git a/orttraining/orttraining/python/training/ortmodule/_logger.py b/orttraining/orttraining/python/training/ortmodule/_logger.py index 4d54e8e59fb50..00acae9061495 100644 --- a/orttraining/orttraining/python/training/ortmodule/_logger.py +++ b/orttraining/orttraining/python/training/ortmodule/_logger.py @@ -9,10 +9,10 @@ import tempfile import textwrap import time +from collections.abc import Callable from contextlib import contextmanager from enum import IntEnum from functools import partial -from typing import Callable, Dict, List, Optional from onnxruntime.capi._pybind_state import Severity @@ -28,7 +28,7 @@ class LogLevel(IntEnum): FATAL = 5 -ORTMODULE_LOG_LEVEL_MAP: Dict[LogLevel, List[int]] = { +ORTMODULE_LOG_LEVEL_MAP: dict[LogLevel, list[int]] = { LogLevel.VERBOSE: [Severity.VERBOSE, logging.DEBUG], LogLevel.DEVINFO: [Severity.INFO, logging.INFO], # ONNX Runtime has too many INFO logs, so we map it to WARNING for a better user experience. @@ -107,8 +107,8 @@ class TimeTracker: def __init__( self, ): - self.starts_: List[float] = [TimeTracker.NOT_RECORD] * len(ORTModuleInitPhase) - self.ends_: List[float] = [TimeTracker.NOT_RECORD] * len(ORTModuleInitPhase) + self.starts_: list[float] = [TimeTracker.NOT_RECORD] * len(ORTModuleInitPhase) + self.ends_: list[float] = [TimeTracker.NOT_RECORD] * len(ORTModuleInitPhase) def start(self, phase: ORTModuleInitPhase): self.starts_[phase] = time.time() @@ -184,7 +184,7 @@ def wrapper(*args, **kwargs): @contextmanager -def _suppress_os_stream_output(enable=True, on_exit: Optional[Callable] = None): +def _suppress_os_stream_output(enable=True, on_exit: Callable | None = None): """Suppress output from being printed to stdout and stderr. If on_exit is not None, it will be called when the context manager exits. @@ -224,7 +224,7 @@ def _suppress_os_stream_output(enable=True, on_exit: Optional[Callable] = None): yield -def _log_with_filter(logger: logging.Logger, record_filters: Optional[List[str]], name: Optional[str], fo): +def _log_with_filter(logger: logging.Logger, record_filters: list[str] | None, name: str | None, fo): """Log the content by filtering with list of string patterns. Args: logger: The logger to log the content. diff --git a/orttraining/orttraining/python/training/ortmodule/_onnx_models.py b/orttraining/orttraining/python/training/ortmodule/_onnx_models.py index 4b6011f0786ec..3f9262bc010c2 100644 --- a/orttraining/orttraining/python/training/ortmodule/_onnx_models.py +++ b/orttraining/orttraining/python/training/ortmodule/_onnx_models.py @@ -4,7 +4,6 @@ import os from dataclasses import dataclass -from typing import Optional import onnx import torch @@ -31,7 +30,7 @@ class ONNXModels: It has further optimizations done by the InferenceSession and is saved by the InferenceSession. """ - optimized_model: Optional[onnx.ModelProto] = None + optimized_model: onnx.ModelProto | None = None def save_optimized_model(self, path, name_prefix, export_mode): # save the ortmodule optimized model diff --git a/orttraining/orttraining/python/training/ortmodule/_runtime_inspector.py b/orttraining/orttraining/python/training/ortmodule/_runtime_inspector.py index c739283e5cafb..6026ecb861efa 100644 --- a/orttraining/orttraining/python/training/ortmodule/_runtime_inspector.py +++ b/orttraining/orttraining/python/training/ortmodule/_runtime_inspector.py @@ -7,7 +7,6 @@ import tempfile from enum import IntEnum from logging import Logger -from typing import Dict, List, Optional, Tuple, Union import onnx import torch @@ -66,8 +65,8 @@ class MemoryOptimizationSummary: def __init__(self, saving_str="", simplified_saving_expr=None, evaluated_saving=None, freq=0): self.raw_symbolic_saving_str = saving_str - self.simplified_symbolic_saving_expr: Optional[Symbol] = simplified_saving_expr - self.evaluated_saving: Union[str, int, None] = evaluated_saving + self.simplified_symbolic_saving_expr: Symbol | None = simplified_saving_expr + self.evaluated_saving: str | int | None = evaluated_saving self.freq = freq @@ -93,9 +92,9 @@ def __init__(self, m: torch.nn.Module, logger: Logger, training: bool): self._is_enabled = True # Memory optimization related. - self.cluster_id_combination_to_saving_symbolics_map: Dict[str, MemoryOptimizationSummary] = {} + self.cluster_id_combination_to_saving_symbolics_map: dict[str, MemoryOptimizationSummary] = {} ## The value is a list of symbolic dim values parsed from the first batch. - self.symbolic_dim_name_to_value_map: Dict = {} + self.symbolic_dim_name_to_value_map: dict = {} ## Used to control only the first batch is used to collect symbolic dim values. self.symbolic_dim_collecting_completed = False @@ -132,8 +131,8 @@ def enable_memory_stats_by_step(self, print_memory_stats_by_step: bool): def collect_symbolic_dim_values( self, - onnx_input_name_to_dynamic_axes_map: Dict[str, Dict[int, str]], - onnx_input_to_value_map: Dict[str, torch.Tensor], + onnx_input_name_to_dynamic_axes_map: dict[str, dict[int, str]], + onnx_input_to_value_map: dict[str, torch.Tensor], ): """Collect symbolic dim values.""" for input_name, dynamic_axes in onnx_input_name_to_dynamic_axes_map.items(): @@ -169,7 +168,7 @@ def find_memory_optimization_opportunity(self, execution_agent: TrainingAgent, r memory_optimizer_config_file_path, recompute_probe_config, False ) - cluster_id_to_saving_symbol_map: Dict[str, MemoryOptimizationSummary] = {} + cluster_id_to_saving_symbol_map: dict[str, MemoryOptimizationSummary] = {} for cluster_id, memory_saving_stat in memory_optimization_saving_symbolics.items(): memory_saving_symbolic = memory_saving_stat[0] freq = memory_saving_stat[1] @@ -282,7 +281,7 @@ def _increase_step(self): def display_memory_optimization_plans( self, memory_optimizer_config_file_path, details=False - ) -> Tuple[List[str], PTable]: + ) -> tuple[list[str], PTable]: mem_plan_count = len(self.cluster_id_combination_to_saving_symbolics_map) if mem_plan_count > 0: @@ -386,9 +385,9 @@ def backward(ctx, grad_output: torch.Tensor): @staticmethod def infer_shape( node: onnx.NodeProto, - tensor_input_shapes: List[Optional[List[Union[int, str]]]], - tensor_input_dtypes: List[torch.onnx.TensorProtoDataType], - ) -> Tuple[List[Optional[List[Union[int, str]]]], List[torch.onnx.TensorProtoDataType]]: + tensor_input_shapes: list[list[int | str] | None], + tensor_input_dtypes: list[torch.onnx.TensorProtoDataType], + ) -> tuple[list[list[int | str] | None], list[torch.onnx.TensorProtoDataType]]: return tensor_input_shapes, tensor_input_dtypes @staticmethod diff --git a/orttraining/orttraining/python/training/ortmodule/_torch_module_interface.py b/orttraining/orttraining/python/training/ortmodule/_torch_module_interface.py index 897bf89c15063..2ae3c98137cbd 100644 --- a/orttraining/orttraining/python/training/ortmodule/_torch_module_interface.py +++ b/orttraining/orttraining/python/training/ortmodule/_torch_module_interface.py @@ -3,7 +3,8 @@ # _torch_module_interface.py from collections import OrderedDict -from typing import Callable, Iterator, Optional, Tuple, TypeVar +from collections.abc import Callable, Iterator +from typing import Optional, TypeVar import torch @@ -58,10 +59,10 @@ def state_dict(self, destination=None, prefix="", keep_vars=False): def load_state_dict(self, state_dict: "OrderedDict[str, torch.Tensor]", strict: bool = True): raise NotImplementedError(f"load_state_dict is not implemented for {type(self)}.") - def register_buffer(self, name: str, tensor: Optional[torch.Tensor], persistent: bool = True) -> None: + def register_buffer(self, name: str, tensor: torch.Tensor | None, persistent: bool = True) -> None: raise NotImplementedError(f"register_buffer is not implemented for {type(self)}.") - def register_parameter(self, name: str, param: Optional[torch.nn.Parameter]) -> None: + def register_parameter(self, name: str, param: torch.nn.Parameter | None) -> None: raise NotImplementedError(f"register_parameter is not implemented for {type(self)}.") def get_parameter(self, target: str) -> torch.nn.Parameter: @@ -73,13 +74,13 @@ def get_buffer(self, target: str) -> torch.Tensor: def parameters(self, recurse: bool = True) -> Iterator[torch.nn.Parameter]: raise NotImplementedError(f"parameters is not implemented for {type(self)}.") - def named_parameters(self, prefix: str = "", recurse: bool = True) -> Iterator[Tuple[str, torch.nn.Parameter]]: + def named_parameters(self, prefix: str = "", recurse: bool = True) -> Iterator[tuple[str, torch.nn.Parameter]]: raise NotImplementedError(f"named_parameters is not implemented for {type(self)}.") def buffers(self, recurse: bool = True) -> Iterator[torch.Tensor]: raise NotImplementedError(f"buffers is not implemented for {type(self)}.") - def named_buffers(self, prefix: str = "", recurse: bool = True) -> Iterator[Tuple[str, torch.Tensor]]: + def named_buffers(self, prefix: str = "", recurse: bool = True) -> Iterator[tuple[str, torch.Tensor]]: raise NotImplementedError(f"named_buffers is not implemented for {type(self)}.") def _load_from_state_dict( @@ -87,7 +88,7 @@ def _load_from_state_dict( ): raise NotImplementedError(f"_load_from_state_dict is not implemented for {type(self)}.") - def named_children(self) -> Iterator[Tuple[str, T]]: + def named_children(self) -> Iterator[tuple[str, T]]: raise NotImplementedError(f"named_children is not implemented for {type(self)}.") def modules(self) -> Iterator[T]: diff --git a/orttraining/orttraining/python/training/ortmodule/_torch_module_ort.py b/orttraining/orttraining/python/training/ortmodule/_torch_module_ort.py index 125590902294d..2ed346fe0bfa6 100644 --- a/orttraining/orttraining/python/training/ortmodule/_torch_module_ort.py +++ b/orttraining/orttraining/python/training/ortmodule/_torch_module_ort.py @@ -3,8 +3,9 @@ # _torch_module_ort.py from collections import OrderedDict +from collections.abc import Callable, Iterator from logging import Logger -from typing import Callable, Iterator, Optional, Tuple, TypeVar +from typing import Optional, TypeVar import torch @@ -75,12 +76,12 @@ def load_state_dict(self, state_dict: "OrderedDict[str, torch.Tensor]", strict: # key names does not need to contain the _module.flattened_module._original_module prefix return self._original_module.load_state_dict(state_dict, strict=strict) - def register_buffer(self, name: str, tensor: Optional[torch.Tensor], persistent: bool = True) -> None: + def register_buffer(self, name: str, tensor: torch.Tensor | None, persistent: bool = True) -> None: """Override original method to delegate execution to the original PyTorch user module""" self._original_module.register_buffer(name, tensor, persistent=persistent) - def register_parameter(self, name: str, param: Optional[torch.nn.Parameter]) -> None: + def register_parameter(self, name: str, param: torch.nn.Parameter | None) -> None: """Override original method to delegate execution to the original PyTorch user module""" self._original_module.register_parameter(name, param) @@ -100,7 +101,7 @@ def parameters(self, recurse: bool = True) -> Iterator[torch.nn.Parameter]: yield from self._original_module.parameters(recurse=recurse) - def named_parameters(self, prefix: str = "", recurse: bool = True) -> Iterator[Tuple[str, torch.nn.Parameter]]: + def named_parameters(self, prefix: str = "", recurse: bool = True) -> Iterator[tuple[str, torch.nn.Parameter]]: """Override original method to delegate execution to the original PyTorch user module""" yield from self._original_module.named_parameters(prefix=prefix, recurse=recurse) @@ -110,7 +111,7 @@ def buffers(self, recurse: bool = True) -> Iterator[torch.Tensor]: yield from self._original_module.buffers(recurse=recurse) - def named_buffers(self, prefix: str = "", recurse: bool = True) -> Iterator[Tuple[str, torch.Tensor]]: + def named_buffers(self, prefix: str = "", recurse: bool = True) -> Iterator[tuple[str, torch.Tensor]]: """Override original method to delegate execution to the original PyTorch user module""" yield from self._original_module.named_buffers(prefix=prefix, recurse=recurse) @@ -129,7 +130,7 @@ def _load_from_state_dict( state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs ) - def named_children(self) -> Iterator[Tuple[str, T]]: + def named_children(self) -> Iterator[tuple[str, T]]: """Override original method to delegate execution to the original PyTorch user module""" yield from self._original_module.named_children() diff --git a/orttraining/orttraining/python/training/ortmodule/_torch_module_pytorch.py b/orttraining/orttraining/python/training/ortmodule/_torch_module_pytorch.py index 9f7fb1d0dcd16..2c38e98cc8657 100644 --- a/orttraining/orttraining/python/training/ortmodule/_torch_module_pytorch.py +++ b/orttraining/orttraining/python/training/ortmodule/_torch_module_pytorch.py @@ -3,7 +3,8 @@ # _torch_module_pytorch.py from collections import OrderedDict -from typing import Callable, Iterator, Optional, Tuple, TypeVar +from collections.abc import Callable, Iterator +from typing import Optional, TypeVar import torch @@ -38,10 +39,10 @@ def state_dict(self, destination=None, prefix="", keep_vars=False): def load_state_dict(self, state_dict: "OrderedDict[str, torch.Tensor]", strict: bool = True): return self._original_module.load_state_dict(state_dict, strict=strict) - def register_buffer(self, name: str, tensor: Optional[torch.Tensor], persistent: bool = True) -> None: + def register_buffer(self, name: str, tensor: torch.Tensor | None, persistent: bool = True) -> None: self._original_module.register_buffer(name, tensor, persistent=persistent) - def register_parameter(self, name: str, param: Optional[torch.nn.Parameter]) -> None: + def register_parameter(self, name: str, param: torch.nn.Parameter | None) -> None: self._original_module.register_parameter(name, param) def get_parameter(self, target: str) -> torch.nn.Parameter: @@ -53,13 +54,13 @@ def get_buffer(self, target: str) -> torch.Tensor: def parameters(self, recurse: bool = True) -> Iterator[torch.nn.Parameter]: yield from self._original_module.parameters(recurse=recurse) - def named_parameters(self, prefix: str = "", recurse: bool = True) -> Iterator[Tuple[str, torch.nn.Parameter]]: + def named_parameters(self, prefix: str = "", recurse: bool = True) -> Iterator[tuple[str, torch.nn.Parameter]]: yield from self._original_module.named_parameters(prefix=prefix, recurse=recurse) def buffers(self, recurse: bool = True) -> Iterator[torch.Tensor]: yield from self._original_module.buffers(recurse=recurse) - def named_buffers(self, prefix: str = "", recurse: bool = True) -> Iterator[Tuple[str, torch.Tensor]]: + def named_buffers(self, prefix: str = "", recurse: bool = True) -> Iterator[tuple[str, torch.Tensor]]: yield from self._original_module.named_buffers(prefix=prefix, recurse=recurse) def _load_from_state_dict( @@ -69,7 +70,7 @@ def _load_from_state_dict( state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs ) - def named_children(self) -> Iterator[Tuple[str, T]]: + def named_children(self) -> Iterator[tuple[str, T]]: yield from self._original_module.named_children() def modules(self) -> Iterator[T]: diff --git a/orttraining/orttraining/python/training/ortmodule/_training_manager.py b/orttraining/orttraining/python/training/ortmodule/_training_manager.py index d5d5ce672224c..b4303587e69e6 100644 --- a/orttraining/orttraining/python/training/ortmodule/_training_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_training_manager.py @@ -4,7 +4,6 @@ # -------------------------------------------------------------------------- from logging import Logger -from typing import Tuple import onnx import torch @@ -48,7 +47,7 @@ def execution_session_run_forward( device: torch.device, gradient_accumulation_manager: GradientAccumulationManager, *inputs, - ) -> Tuple[Tuple[torch.Tensor, ...], _RunStateInfo]: + ) -> tuple[tuple[torch.Tensor, ...], _RunStateInfo]: """Runs the forward pass on `execution_session` with given `onnx_model`, `device` and `inputs` Args: @@ -85,7 +84,7 @@ def execution_session_run_forward( # Run and return module outputs. execution_session.run_forward(forward_inputs, forward_outputs, state, gradient_accumulation_manager.cache) - user_outputs: Tuple[torch.Tensor, ...] = gradient_accumulation_manager.extract_outputs_and_maybe_update_cache( + user_outputs: tuple[torch.Tensor, ...] = gradient_accumulation_manager.extract_outputs_and_maybe_update_cache( forward_outputs, device ) diff --git a/orttraining/orttraining/python/training/ortmodule/_utils.py b/orttraining/orttraining/python/training/ortmodule/_utils.py index 4787cb31a24fd..2e115654e4c96 100644 --- a/orttraining/orttraining/python/training/ortmodule/_utils.py +++ b/orttraining/orttraining/python/training/ortmodule/_utils.py @@ -13,7 +13,7 @@ import random import traceback import types -from typing import Callable, List, Optional, Tuple, Union +from collections.abc import Callable import numpy as np import torch @@ -63,8 +63,8 @@ def _ortvalue_from_torch_tensor(torch_tensor: torch.Tensor) -> C.OrtValue: def _ortvalues_to_torch_tensor( - ortvalues: C.OrtValueVector, device: Optional[torch.device] = None -) -> Tuple[torch.Tensor, ...]: + ortvalues: C.OrtValueVector, device: torch.device | None = None +) -> tuple[torch.Tensor, ...]: if len(ortvalues) == 0: return tuple() @@ -76,7 +76,7 @@ def _ortvalues_to_torch_tensor( if not isinstance(ortvalues, C.OrtValueVector): raise TypeError(f"ortvalues must be an instance of OrtValueVector not {type(ortvalues)!r}.") - res: List[torch.Tensor] = ortvalues.to_dlpacks(_from_dlpack) + res: list[torch.Tensor] = ortvalues.to_dlpacks(_from_dlpack) bool_indices = ortvalues.bool_tensor_indices() if len(bool_indices): # DLPack structure does not know for sure if it stores boolean @@ -127,7 +127,7 @@ def _check_same_device(device: torch.device, argument_str: str, *args): ) -def get_device_index(device: Union[str, int, torch.device]) -> int: +def get_device_index(device: str | int | torch.device) -> int: if isinstance(device, str): # could be 'cuda:0', 'cuda:1', or 'cpu'. with cpu, set index=0 device = torch.device(device) @@ -136,7 +136,7 @@ def get_device_index(device: Union[str, int, torch.device]) -> int: return 0 if device.index is None else device.index -def get_device_str(device: Union[str, int, torch.device]) -> str: +def get_device_str(device: str | int | torch.device) -> str: if isinstance(device, str): # could be 'cuda:0', 'cuda:1', or 'cpu'. with cpu, set index=0 if device.find(":") == -1: @@ -161,7 +161,7 @@ def get_device_from_module_and_inputs(module, inputs, kwargs): return device -def _get_device_from_module(module) -> Optional[torch.device]: +def _get_device_from_module(module) -> torch.device | None: """Returns the first device found in the `module`'s parameters or None Args: @@ -187,7 +187,7 @@ def _get_device_from_module(module) -> Optional[torch.device]: return device -def _get_device_from_inputs(args, kwargs) -> Optional[torch.device]: +def _get_device_from_inputs(args, kwargs) -> torch.device | None: """Returns device from first PyTorch Tensor within args or kwargs Args: diff --git a/orttraining/orttraining/python/training/ortmodule/_zero_stage3_compatibility.py b/orttraining/orttraining/python/training/ortmodule/_zero_stage3_compatibility.py index 7da3e18007447..75601d0c828b8 100644 --- a/orttraining/orttraining/python/training/ortmodule/_zero_stage3_compatibility.py +++ b/orttraining/orttraining/python/training/ortmodule/_zero_stage3_compatibility.py @@ -5,7 +5,6 @@ from contextlib import contextmanager -from typing import Dict, List, Optional, Tuple, Union import torch from onnx import ModelProto, NodeProto, TensorProto, ValueInfoProto, helper @@ -31,8 +30,8 @@ def post_processing_enable_zero_stage3_compat( exported_model: ModelProto, - zero_stage3_named_params: Dict[str, torch.nn.parameter.Parameter], - all_param_names: List[str], + zero_stage3_named_params: dict[str, torch.nn.parameter.Parameter], + all_param_names: list[str], ) -> ModelProto: """This function is used to enable zero stage3 compatibility. @@ -62,7 +61,7 @@ def post_processing_enable_zero_stage3_compat( def _get_param_pull_trigger_name(param_name: str) -> str: return f"pull_{param_name}" - def _get_func_name(node: NodeProto) -> Optional[str]: + def _get_func_name(node: NodeProto) -> str | None: for attr in node.attribute: if attr.name == "func_name": return attr.s.decode("utf-8") if isinstance(attr.s, bytes) else attr.s @@ -210,7 +209,7 @@ def _get_func_name(node: NodeProto) -> Optional[str]: def _create_weight_retrieval_function( - zero_stage3_named_params: Optional[Dict[str, torch.nn.parameter.Parameter]], + zero_stage3_named_params: dict[str, torch.nn.parameter.Parameter] | None, ) -> str: """This function is used to create a weight retrieving function using zero_stage3_named_params.""" @@ -231,9 +230,9 @@ def backward(ctx, *grad_outputs): @staticmethod def infer_shape( node: NodeProto, - tensor_input_shapes: List[Optional[List[Union[int, str]]]], - tensor_input_dtypes: List[torch.onnx.TensorProtoDataType], - ) -> Tuple[List[Optional[List[Union[int, str]]]], List[torch.onnx.TensorProtoDataType]]: + tensor_input_shapes: list[list[int | str] | None], + tensor_input_dtypes: list[torch.onnx.TensorProtoDataType], + ) -> tuple[list[list[int | str] | None], list[torch.onnx.TensorProtoDataType]]: param_count = len(zero_stage3_named_params.values()) tensor_output_shapes = [ tensor_input_shapes[0], @@ -258,9 +257,9 @@ def _register_symbolic_shape_infer_functions(): def _simple_pass_through_infer_shape( node: NodeProto, - tensor_input_shapes: List[Optional[List[Union[int, str]]]], - tensor_input_dtypes: List[torch.onnx.TensorProtoDataType], - ) -> Tuple[List[Optional[List[Union[int, str]]]], List[torch.onnx.TensorProtoDataType]]: + tensor_input_shapes: list[list[int | str] | None], + tensor_input_dtypes: list[torch.onnx.TensorProtoDataType], + ) -> tuple[list[list[int | str] | None], list[torch.onnx.TensorProtoDataType]]: return tensor_input_shapes, tensor_input_dtypes register_shape_inference_function(DEEPSPEED_PRE_BACKWARD_FUNCTION_NAME, _simple_pass_through_infer_shape) @@ -268,9 +267,9 @@ def _simple_pass_through_infer_shape( def _linear_infer_shape( node: NodeProto, - tensor_input_shapes: List[Optional[List[Union[int, str]]]], - tensor_input_dtypes: List[torch.onnx.TensorProtoDataType], - ) -> Tuple[List[Optional[List[Union[int, str]]]], List[torch.onnx.TensorProtoDataType]]: + tensor_input_shapes: list[list[int | str] | None], + tensor_input_dtypes: list[torch.onnx.TensorProtoDataType], + ) -> tuple[list[list[int | str] | None], list[torch.onnx.TensorProtoDataType]]: # output = input.matmul(weight.t()) tensor_input_shapes[0] # input shape2 = tensor_input_shapes[1] # weight @@ -311,13 +310,13 @@ def _alias_input(node_proto_str: str): def _create_weight_retrieval_pythonop( - zero_stage3_named_params: Optional[Dict[str, torch.nn.parameter.Parameter]], + zero_stage3_named_params: dict[str, torch.nn.parameter.Parameter] | None, func_full_qual_name: str, input_name: str, - output_names: List[str], + output_names: list[str], pull_weight_trigger_output_dtype: int, - pull_weight_trigger_output_shape: List[int], -) -> Tuple[ValueInfoProto, NodeProto]: + pull_weight_trigger_output_shape: list[int], +) -> tuple[ValueInfoProto, NodeProto]: """This function is used to create a weight retrieving PythonOp.""" offload_param_count = 0 if zero_stage3_named_params is None else len(zero_stage3_named_params) new_input = helper.make_tensor_value_info( @@ -417,7 +416,7 @@ def stage3_export_context(enable: bool, stage3_param_handle, flattened_module): from torch.onnx._internal import _beartype @_beartype.beartype - def _get_tensor_rank(x) -> Optional[int]: + def _get_tensor_rank(x) -> int | None: ### Adapted from https://github.com/pytorch/pytorch/blob/185515368bcd7d94ac06ab1634f22b747b03c6d9/torch/onnx/symbolic_helper.py#L561 # Retrieve the real rank for the stage3 weights, because stage3 weights are all (0). from typing import cast as typing_cast diff --git a/orttraining/orttraining/python/training/ortmodule/graph_optimizer_registry.py b/orttraining/orttraining/python/training/ortmodule/graph_optimizer_registry.py index 897ecac148bfb..fa4c6dd04d81b 100644 --- a/orttraining/orttraining/python/training/ortmodule/graph_optimizer_registry.py +++ b/orttraining/orttraining/python/training/ortmodule/graph_optimizer_registry.py @@ -3,7 +3,7 @@ # Licensed under the MIT License. # -------------------------------------------------------------------------- -from typing import Callable +from collections.abc import Callable from onnx.onnx_ml_pb2 import GraphProto diff --git a/orttraining/orttraining/python/training/ortmodule/graph_optimizers/_aten_attn.py b/orttraining/orttraining/python/training/ortmodule/graph_optimizers/_aten_attn.py index c1fb6e68568f5..b5e5ae45f3631 100644 --- a/orttraining/orttraining/python/training/ortmodule/graph_optimizers/_aten_attn.py +++ b/orttraining/orttraining/python/training/ortmodule/graph_optimizers/_aten_attn.py @@ -15,8 +15,6 @@ support if we want to try in the future. """ -from typing import List, Tuple - from onnx import GraphProto, NodeProto, TensorProto, helper from ..graph_optimizer_registry import register_graph_optimizer @@ -125,7 +123,7 @@ def _make_efficient_attention_nodes( # Without causal mask, with Dropout. For example, BERT model in HuggingFace. -_PATTERN_0: List[Tuple[str, bool, List[Tuple[int, int, int]]]] = [ +_PATTERN_0: list[tuple[str, bool, list[tuple[int, int, int]]]] = [ ("MatMul", False, []), # 0 ("Transpose", True, [(0, 0, 0)]), # 1 ("Transpose", True, [(0, 0, 1)]), # 2 @@ -152,7 +150,7 @@ def _make_efficient_attention_nodes( ] -def _optimize_for_pattern_0(matcher: GraphMatcher, idx: int, nodes: List[NodeProto]): +def _optimize_for_pattern_0(matcher: GraphMatcher, idx: int, nodes: list[NodeProto]): # Check forward only as the backward is expected to be consistent if it's built correctly. scale_value = matcher.get_constant_value(nodes[3].input[1]) ratio_value = matcher.get_constant_value(nodes[6].input[1]) @@ -188,7 +186,7 @@ def _optimize_for_pattern_0(matcher: GraphMatcher, idx: int, nodes: List[NodePro # Without causal mask, without Dropout. For example, BERT model and disabling attention dropout in HuggingFace. -_PATTERN_1: List[Tuple[str, bool, List[Tuple[int, int, int]]]] = [ +_PATTERN_1: list[tuple[str, bool, list[tuple[int, int, int]]]] = [ ("MatMul", False, []), # 0 ("Transpose", True, [(0, 0, 0)]), # 1 ("Transpose", True, [(0, 0, 1)]), # 2 @@ -213,7 +211,7 @@ def _optimize_for_pattern_0(matcher: GraphMatcher, idx: int, nodes: List[NodePro ] -def _optimize_for_pattern_1(matcher: GraphMatcher, idx: int, nodes: List[NodeProto]): +def _optimize_for_pattern_1(matcher: GraphMatcher, idx: int, nodes: list[NodeProto]): # Check forward only as the backward is expected to be consistent if it's built correctly. scale_value = matcher.get_constant_value(nodes[3].input[1]) if not ( diff --git a/orttraining/orttraining/python/training/ortmodule/graph_optimizers/utils.py b/orttraining/orttraining/python/training/ortmodule/graph_optimizers/utils.py index fbd98675aebe6..9089004559923 100644 --- a/orttraining/orttraining/python/training/ortmodule/graph_optimizers/utils.py +++ b/orttraining/orttraining/python/training/ortmodule/graph_optimizers/utils.py @@ -4,7 +4,8 @@ # -------------------------------------------------------------------------- import itertools -from typing import Any, Dict, List, Sequence, Tuple +from collections.abc import Sequence +from typing import Any import numpy as np from onnx import GraphProto, NodeProto, TensorProto, helper, numpy_helper @@ -54,8 +55,8 @@ class GraphMatcher: def __init__(self, graph: GraphProto): self._graph: GraphProto = graph - self._op_type_to_nodes: Dict[str, List[NodeProto]] = {} - self._consumer_count: Dict[str, int] = {} + self._op_type_to_nodes: dict[str, list[NodeProto]] = {} + self._consumer_count: dict[str, int] = {} for node in graph.node: if node.op_type not in self._op_type_to_nodes: self._op_type_to_nodes[node.op_type] = [] @@ -117,7 +118,7 @@ def get_type_and_shape(self, arg: str): return initializers[0].data_type, initializers[0].dims return None, None - def _match_pattern(self, node: NodeProto, pattern: List[Tuple[str, bool, List[Tuple[int, int, int]]]]): + def _match_pattern(self, node: NodeProto, pattern: list[tuple[str, bool, list[tuple[int, int, int]]]]): nodes = [node] for i in range(1, len(pattern)): next_op_type = pattern[i][0] @@ -140,7 +141,7 @@ def _match_pattern(self, node: NodeProto, pattern: List[Tuple[str, bool, List[Tu nodes.append(next_node) return nodes - def match_pattern(self, pattern: List[Tuple[str, bool, List[Tuple[int, int, int]]]]): + def match_pattern(self, pattern: list[tuple[str, bool, list[tuple[int, int, int]]]]): for node in self._op_type_to_nodes.get(pattern[0][0], []): result = self._match_pattern(node, pattern) if len(result) == len(pattern): @@ -165,9 +166,9 @@ def make_constant_node(name: str, dtype: TensorProto.DataType, dims: Sequence[in def update_graph( graph: GraphProto, - nodes_to_remove: List[NodeProto], - nodes_to_add: List[NodeProto], - new_value_infos: List[TensorProto] = [], # noqa: B006 + nodes_to_remove: list[NodeProto], + nodes_to_add: list[NodeProto], + new_value_infos: list[TensorProto] = [], # noqa: B006 ): """Update an ONNX graph by removing some nodes, and adding some new nodes and value infos.""" nodes = [node for node in graph.node if node not in nodes_to_remove] diff --git a/orttraining/orttraining/python/training/ortmodule/ortmodule.py b/orttraining/orttraining/python/training/ortmodule/ortmodule.py index b291bfb2ba03c..a7942eea5be26 100644 --- a/orttraining/orttraining/python/training/ortmodule/ortmodule.py +++ b/orttraining/orttraining/python/training/ortmodule/ortmodule.py @@ -18,7 +18,9 @@ from onnxruntime.tools import pytorch_export_contrib_ops import torch -from typing import Iterator, Optional, OrderedDict, Tuple, TypeVar, Callable +from typing import TypeVar +from collections import OrderedDict +from collections.abc import Iterator, Callable # Needed to override PyTorch methods T = TypeVar("T", bound="torch.nn.Module") @@ -35,7 +37,7 @@ class ORTModule(torch.nn.Module): debug_options (:obj:`DebugOptions`, optional): debugging options for ORTModule. """ - def __init__(self, module: torch.nn.Module, debug_options: Optional[DebugOptions] = None): + def __init__(self, module: torch.nn.Module, debug_options: DebugOptions | None = None): # NOTE: torch.nn.Modules that call setattr on their internal attributes regularly # (for example PyTorch Lightning), will trigger regular re-exports. This is # because ORTModule auto detects such setattrs on the original module and @@ -154,7 +156,7 @@ def _replicate_for_data_parallel(self): return self._torch_module._replicate_for_data_parallel() - def add_module(self, name: str, module: Optional[torch.nn.Module]) -> None: + def add_module(self, name: str, module: torch.nn.Module | None) -> None: """Raises a ORTModuleTorchModelException exception since ORTModule does not support adding modules to it""" self._torch_module.add_module(name, module) @@ -217,12 +219,12 @@ def load_state_dict(self, state_dict: "OrderedDict[str, torch.Tensor]", strict: return self._torch_module.load_state_dict(state_dict, strict=strict) - def register_buffer(self, name: str, tensor: Optional[torch.Tensor], persistent: bool = True) -> None: + def register_buffer(self, name: str, tensor: torch.Tensor | None, persistent: bool = True) -> None: """Override :meth:`~torch.nn.Module.register_buffer`""" self._torch_module.register_buffer(name, tensor, persistent=persistent) - def register_parameter(self, name: str, param: Optional[torch.nn.Parameter]) -> None: + def register_parameter(self, name: str, param: torch.nn.Parameter | None) -> None: """Override :meth:`~torch.nn.Module.register_parameter`""" self._torch_module.register_parameter(name, param) @@ -242,7 +244,7 @@ def parameters(self, recurse: bool = True) -> Iterator[torch.nn.Parameter]: yield from self._torch_module.parameters(recurse=recurse) - def named_parameters(self, prefix: str = "", recurse: bool = True) -> Iterator[Tuple[str, torch.nn.Parameter]]: + def named_parameters(self, prefix: str = "", recurse: bool = True) -> Iterator[tuple[str, torch.nn.Parameter]]: """Override :meth:`~torch.nn.Module.named_parameters`""" yield from self._torch_module.named_parameters(prefix=prefix, recurse=recurse) @@ -252,7 +254,7 @@ def buffers(self, recurse: bool = True) -> Iterator[torch.Tensor]: yield from self._torch_module.buffers(recurse=recurse) - def named_buffers(self, prefix: str = "", recurse: bool = True) -> Iterator[Tuple[str, torch.Tensor]]: + def named_buffers(self, prefix: str = "", recurse: bool = True) -> Iterator[tuple[str, torch.Tensor]]: """Override :meth:`~torch.nn.Module.named_buffers`""" yield from self._torch_module.named_buffers(prefix=prefix, recurse=recurse) @@ -266,7 +268,7 @@ def _load_from_state_dict( state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs ) - def named_children(self) -> Iterator[Tuple[str, torch.nn.Module]]: + def named_children(self) -> Iterator[tuple[str, torch.nn.Module]]: """Override :meth:`~torch.nn.Module.named_children`""" yield from self._torch_module.named_children() diff --git a/orttraining/orttraining/python/training/utils/data/sampler.py b/orttraining/orttraining/python/training/utils/data/sampler.py index afc4d360b1582..8dfe576714609 100644 --- a/orttraining/orttraining/python/training/utils/data/sampler.py +++ b/orttraining/orttraining/python/training/utils/data/sampler.py @@ -3,7 +3,7 @@ # sampler.py import math -from typing import Callable, Iterator, Optional +from collections.abc import Callable, Iterator import numpy as np import torch @@ -106,10 +106,10 @@ def __init__( self, dataset: Dataset, complexity_fn: Callable[..., int], - world_size: Optional[int] = None, - rank: Optional[int] = None, + world_size: int | None = None, + rank: int | None = None, shuffle: bool = True, - group_size: Optional[int] = None, + group_size: int | None = None, seed: int = 0, drop_last: bool = False, random_level: float = 0, diff --git a/orttraining/orttraining/python/training/utils/hooks/_statistics_subscriber.py b/orttraining/orttraining/python/training/utils/hooks/_statistics_subscriber.py index d7ea3dc419114..d466faddf91bc 100644 --- a/orttraining/orttraining/python/training/utils/hooks/_statistics_subscriber.py +++ b/orttraining/orttraining/python/training/utils/hooks/_statistics_subscriber.py @@ -8,7 +8,6 @@ import warnings from io import TextIOWrapper from pathlib import Path -from typing import List, Optional, Tuple, Union import onnx import torch @@ -29,7 +28,7 @@ class _InspectActivation(torch.autograd.Function): def forward( ctx, activation_name: str, - module_idx: Optional[int], + module_idx: int | None, run_ctx: RuntimeStates, input_tensor: torch.Tensor, module_post_forward, @@ -89,9 +88,9 @@ def backward(ctx, grad_output: torch.Tensor): @staticmethod def infer_shape( node: onnx.NodeProto, - tensor_input_shapes: List[Optional[List[Union[int, str]]]], - tensor_input_dtypes: List[torch.onnx.TensorProtoDataType], - ) -> Tuple[List[Optional[List[Union[int, str]]]], List[torch.onnx.TensorProtoDataType]]: + tensor_input_shapes: list[list[int | str] | None], + tensor_input_dtypes: list[torch.onnx.TensorProtoDataType], + ) -> tuple[list[list[int | str] | None], list[torch.onnx.TensorProtoDataType]]: return tensor_input_shapes, tensor_input_dtypes @staticmethod @@ -124,8 +123,8 @@ class StatisticsSubscriber(SubscriberBase): def __init__( self, output_dir: str, - start_step: Union[None, int] = None, - end_step: Union[None, int] = None, + start_step: None | int = None, + end_step: None | int = None, override_output_dir: bool = False, run_on_cpu: bool = False, bucket_size: int = 1024 * 1024 * 1024 // 2, diff --git a/orttraining/orttraining/python/training/utils/hooks/_subscriber_base.py b/orttraining/orttraining/python/training/utils/hooks/_subscriber_base.py index 1b9a6fc91ec3c..05c58b86b993f 100644 --- a/orttraining/orttraining/python/training/utils/hooks/_subscriber_base.py +++ b/orttraining/orttraining/python/training/utils/hooks/_subscriber_base.py @@ -5,7 +5,6 @@ import sys -from typing import Optional, Tuple import torch @@ -52,7 +51,7 @@ class SubscriberBase: With this, the overall flow can be traced as a data flow graph (DAG). """ - def __init__(self, start_step: Optional[int], end_step: Optional[int]): + def __init__(self, start_step: int | None, end_step: int | None): """ Steps in [start_step, end_step) will run the subscriber's actions, and other steps will skip. If start_step is None, 0 is given; if end_step is None, sys.maxsize is given. @@ -66,7 +65,7 @@ def pre_forward_module_apply( module: torch.nn.Module, args: ORTModelInputOutputType, kwargs: ORTModelInputOutputType, - ) -> Tuple[ORTModelInputOutputType, ORTModelInputOutputType]: + ) -> tuple[ORTModelInputOutputType, ORTModelInputOutputType]: """This function is called inside the nn.Module's pre-forward hook. Args: @@ -91,7 +90,7 @@ def pre_forward_module_apply_impl( module: torch.nn.Module, args: ORTModelInputOutputType, kwargs: ORTModelInputOutputType, - ) -> Tuple[ORTModelInputOutputType, ORTModelInputOutputType]: + ) -> tuple[ORTModelInputOutputType, ORTModelInputOutputType]: return args, kwargs def pre_forward_tensor_apply( @@ -121,7 +120,7 @@ def post_forward_module_apply( module: torch.nn.Module, args: ORTModelInputOutputType, outputs: ORTModelInputOutputType, - ) -> Tuple[ORTModelInputOutputType, ORTModelInputOutputType]: + ) -> tuple[ORTModelInputOutputType, ORTModelInputOutputType]: """This function is called inside the nn.Module's post-forward hook. Args: @@ -146,7 +145,7 @@ def post_forward_module_apply_impl( module: torch.nn.Module, args: ORTModelInputOutputType, outputs: ORTModelInputOutputType, - ) -> Tuple[ORTModelInputOutputType, ORTModelInputOutputType]: + ) -> tuple[ORTModelInputOutputType, ORTModelInputOutputType]: return args, outputs def post_forward_tensor_apply( @@ -179,7 +178,7 @@ def post_forward_outmost_module_apply( module: torch.nn.Module, args: ORTModelInputOutputType, outputs: ORTModelInputOutputType, - ) -> Tuple[ORTModelInputOutputType, ORTModelInputOutputType]: + ) -> tuple[ORTModelInputOutputType, ORTModelInputOutputType]: """This function is called inside the outmost nn.Module's post-forward hook. Args: @@ -204,7 +203,7 @@ def post_forward_outmost_module_apply_impl( module: torch.nn.Module, args: ORTModelInputOutputType, outputs: ORTModelInputOutputType, - ) -> Tuple[ORTModelInputOutputType, ORTModelInputOutputType]: + ) -> tuple[ORTModelInputOutputType, ORTModelInputOutputType]: return args, outputs def _need_skip_step(self, current_step: int) -> bool: diff --git a/orttraining/orttraining/python/training/utils/hooks/_subscriber_manager.py b/orttraining/orttraining/python/training/utils/hooks/_subscriber_manager.py index c9c06dabab4de..c41f5078b20d7 100644 --- a/orttraining/orttraining/python/training/utils/hooks/_subscriber_manager.py +++ b/orttraining/orttraining/python/training/utils/hooks/_subscriber_manager.py @@ -6,7 +6,6 @@ import inspect from contextlib import contextmanager -from typing import List, Optional, Set, Tuple, Union import onnx import torch @@ -40,7 +39,7 @@ class _IncrementStep(torch.autograd.Function): """ @staticmethod - def forward(ctx, run_ctx: RuntimeStates, *input_tensor_list: Tuple[torch.Tensor, ...]) -> Tuple[torch.Tensor, ...]: + def forward(ctx, run_ctx: RuntimeStates, *input_tensor_list: tuple[torch.Tensor, ...]) -> tuple[torch.Tensor, ...]: """Make sure there is the same number of `tensor` inputs and outputs. This is enforced by ORT's PythonOp's schema check. """ @@ -57,15 +56,15 @@ def forward(ctx, run_ctx: RuntimeStates, *input_tensor_list: Tuple[torch.Tensor, return tuple(t.detach().requires_grad_(t.requires_grad) for t in input_tensor_list) @staticmethod - def backward(ctx, *grad_output: Tuple[Optional[torch.Tensor], ...]) -> Tuple[Optional[torch.Tensor], ...]: + def backward(ctx, *grad_output: tuple[torch.Tensor | None, ...]) -> tuple[torch.Tensor | None, ...]: return (None, *tuple(g for g in grad_output)) @staticmethod def infer_shape( node: onnx.NodeProto, - tensor_input_shapes: List[Optional[List[Union[int, str]]]], - tensor_input_dtypes: List[torch.onnx.TensorProtoDataType], - ) -> Tuple[List[Optional[List[Union[int, str]]]], List[torch.onnx.TensorProtoDataType]]: + tensor_input_shapes: list[list[int | str] | None], + tensor_input_dtypes: list[torch.onnx.TensorProtoDataType], + ) -> tuple[list[list[int | str] | None], list[torch.onnx.TensorProtoDataType]]: return tensor_input_shapes, tensor_input_dtypes @staticmethod @@ -104,11 +103,11 @@ class SubscriberManager: def __init__(self): self._run_ctx = RuntimeStates() - self._subscribers: Set[SubscriberBase] = set() + self._subscribers: set[SubscriberBase] = set() self._pre_forward_hooks = [] self._post_forward_hooks = [] - def subscribe(self, module: torch.nn.Module, subscribers: List[SubscriberBase]): + def subscribe(self, module: torch.nn.Module, subscribers: list[SubscriberBase]): """ The API is called externally to register hooks that are implicitly defined by subscribers. Each time all global states will be cleaned up once called. @@ -192,7 +191,7 @@ def _post_forward_outmost_module_hook(module, module_inputs, module_outputs): module.register_forward_hook(_post_forward_outmost_module_hook) def _initialize_one_time_global_states(self, module: torch.nn.Module): - def _reset_recursively(module: torch.nn.Module, depth: int, next_module_index: List[int]): + def _reset_recursively(module: torch.nn.Module, depth: int, next_module_index: list[int]): """ Called to register hooks for every `torch.nn.Module`. Due to `Module` can contain child `Module`s, this function is called recursively by passing in `next_module_index` - a list of int to maintain a @@ -219,7 +218,7 @@ def _reset_recursively(module: torch.nn.Module, depth: int, next_module_index: L next_module_index = [0] _reset_recursively(module, 1, next_module_index) - def _register_hooks_recursively(self, module: torch.nn.Module, depth: int, next_module_index: List[int]): + def _register_hooks_recursively(self, module: torch.nn.Module, depth: int, next_module_index: list[int]): """Register hooks for every `torch.nn.Module`. Due to `Module` can contain child `Module`s, this function is called recursively by passing in `next_module_index` - a list of int to maintain a global incremental unique module id. diff --git a/orttraining/orttraining/python/training/utils/hooks/_zero_offload_subscriber.py b/orttraining/orttraining/python/training/utils/hooks/_zero_offload_subscriber.py index d4b9768116e92..57078222a22e7 100644 --- a/orttraining/orttraining/python/training/utils/hooks/_zero_offload_subscriber.py +++ b/orttraining/orttraining/python/training/utils/hooks/_zero_offload_subscriber.py @@ -7,9 +7,10 @@ import inspect import warnings from collections import OrderedDict +from collections.abc import Callable from datetime import timedelta from types import CodeType, FunctionType -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Any import onnx import torch @@ -80,7 +81,7 @@ def source_rank(self) -> int: def _source_rank(self) -> int: return 0 - def result(self) -> List[torch.Tensor]: + def result(self) -> list[torch.Tensor]: return [] def synchronize(self): @@ -177,7 +178,7 @@ def configure_ort_compatible_zero_stage3(debug=False, stats_output_dir=None, sta @nvtx_function_decorator -def _get_params_for_current_module(module: torch.nn.Module) -> List[torch.nn.parameter.Parameter]: +def _get_params_for_current_module(module: torch.nn.Module) -> list[torch.nn.parameter.Parameter]: """Retrieve the parameters for this module. Logic adapted from @@ -192,7 +193,7 @@ def _get_params_for_current_module(module: torch.nn.Module) -> List[torch.nn.par @nvtx_function_decorator -def _get_all_zero_stage3_params(module: torch.nn.Module) -> Dict[str, torch.nn.parameter.Parameter]: +def _get_all_zero_stage3_params(module: torch.nn.Module) -> dict[str, torch.nn.parameter.Parameter]: """Retrieve all the parameters that are offloaded.""" from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus @@ -205,7 +206,7 @@ def _get_all_zero_stage3_params(module: torch.nn.Module) -> Dict[str, torch.nn.p # Used to cache the map avoid repeated loop up (X us) overhead during training. -_ModuleToParametersRefs: Dict[torch.nn.Module, List[torch.nn.parameter.Parameter]] = OrderedDict() +_ModuleToParametersRefs: dict[torch.nn.Module, list[torch.nn.parameter.Parameter]] = OrderedDict() class ORTZeROOffloadPreForwardFunction(torch.autograd.Function): @@ -295,7 +296,7 @@ def backward(ctx, *grads): # completing the full backward propagation, will not affect parameter updates. passed_in_param_grad = [ torch.zeros(shape, dtype=dtype, device=device) - for shape, dtype, device in zip(ctx.shapes, ctx.dtypes, ctx.devices) + for shape, dtype, device in zip(ctx.shapes, ctx.dtypes, ctx.devices, strict=False) ] zero_grads = updated_grads[:input_count] + tuple(passed_in_param_grad) @@ -306,9 +307,9 @@ def backward(ctx, *grads): @staticmethod def infer_shape( node: onnx.NodeProto, - tensor_input_shapes: List[Optional[List[Union[int, str]]]], - tensor_input_dtypes: List[torch.onnx.TensorProtoDataType], - ) -> Tuple[List[Optional[List[Union[int, str]]]], List[torch.onnx.TensorProtoDataType]]: + tensor_input_shapes: list[list[int | str] | None], + tensor_input_dtypes: list[torch.onnx.TensorProtoDataType], + ) -> tuple[list[list[int | str] | None], list[torch.onnx.TensorProtoDataType]]: input_pointer_scalars_attr_name = "input_pointer_scalars" found = [attr for attr in node.attribute if attr.name == input_pointer_scalars_attr_name] assert len(found) == 1 @@ -414,9 +415,9 @@ def backward(ctx, *grads): @staticmethod def infer_shape( node: onnx.NodeProto, - tensor_input_shapes: List[Optional[List[Union[int, str]]]], - tensor_input_dtypes: List[torch.onnx.TensorProtoDataType], - ) -> Tuple[List[Optional[List[Union[int, str]]]], List[torch.onnx.TensorProtoDataType]]: + tensor_input_shapes: list[list[int | str] | None], + tensor_input_dtypes: list[torch.onnx.TensorProtoDataType], + ) -> tuple[list[list[int | str] | None], list[torch.onnx.TensorProtoDataType]]: return tensor_input_shapes, tensor_input_dtypes @staticmethod @@ -480,7 +481,7 @@ def pre_forward_module_apply_impl( module: torch.nn.Module, args: ORTModelInputOutputType, kwargs: ORTModelInputOutputType, - ) -> Tuple[ORTModelInputOutputType, ORTModelInputOutputType]: + ) -> tuple[ORTModelInputOutputType, ORTModelInputOutputType]: """This function is a dispatcher to call DeepSpeed stage3 pre forward hooks in sequence. All hook functions can be retrieved from the function store, due to exporter only supports a list of tensors as @@ -556,7 +557,7 @@ def post_forward_module_apply_impl( module: torch.nn.Module, args: ORTModelInputOutputType, outputs: ORTModelInputOutputType, - ) -> Tuple[ORTModelInputOutputType, ORTModelInputOutputType]: + ) -> tuple[ORTModelInputOutputType, ORTModelInputOutputType]: """This function is a dispatcher to call DeepSpeed stage3 post forward hooks in sequence. All hook functions can be retrieved from function store, due to exporter only supports a list of tensors as @@ -615,7 +616,7 @@ def post_forward_outmost_module_apply_impl( module: torch.nn.Module, args: ORTModelInputOutputType, outputs: ORTModelInputOutputType, - ) -> Tuple[ORTModelInputOutputType, ORTModelInputOutputType]: + ) -> tuple[ORTModelInputOutputType, ORTModelInputOutputType]: outputs_tensors, outputs_schema = extract_data_and_schema(outputs) _end_of_forward_hook = self._functions.get("_end_of_forward_hook") @@ -636,7 +637,7 @@ def post_forward_outmost_module_apply_impl( return args, updated_outputs @nvtx_function_decorator - def _check_all_tensor(self, tensor_list: Tuple[torch.Tensor], module: torch.nn.Module, name: str): + def _check_all_tensor(self, tensor_list: tuple[torch.Tensor], module: torch.nn.Module, name: str): if not self._enable_debug_info: return diff --git a/orttraining/orttraining/python/training/utils/ptable.py b/orttraining/orttraining/python/training/utils/ptable.py index 5e06864800666..c3e022f252e13 100644 --- a/orttraining/orttraining/python/training/utils/ptable.py +++ b/orttraining/orttraining/python/training/utils/ptable.py @@ -3,14 +3,12 @@ # Licensed under the MIT License. # -------------------------------------------------------------------------- -from typing import List - class Row: """A row in a PTable""" - def __init__(self, columns: List[str]) -> None: - self._columns: List[str] = columns # List of strings + def __init__(self, columns: list[str]) -> None: + self._columns: list[str] = columns # List of strings self._annotation_table = None # Optional PTable used for displaying detailed information about the feature row. def append_annotation_table(self, ptable) -> None: @@ -21,11 +19,11 @@ class PTable: """A table that can be printed to the console.""" def __init__(self, sortable=False) -> None: - self._rows: List[Row] = [] + self._rows: list[Row] = [] self._column_count = None self._sortable = sortable # allow the rows to be sorted by the first column - def add_row(self, columns: List[str]) -> Row: + def add_row(self, columns: list[str]) -> Row: """Add a row to the table. The number of columns must match the number of columns in the table.""" if self._column_count is None: self._column_count = len(columns) diff --git a/orttraining/orttraining/python/training/utils/torch_io_helper.py b/orttraining/orttraining/python/training/utils/torch_io_helper.py index a6aa390a3ed35..f0cf09d91b81e 100644 --- a/orttraining/orttraining/python/training/utils/torch_io_helper.py +++ b/orttraining/orttraining/python/training/utils/torch_io_helper.py @@ -6,7 +6,7 @@ import copy import warnings from collections import OrderedDict, abc -from typing import List, Mapping, Optional, Sequence, Tuple, Union +from collections.abc import Mapping, Sequence import torch @@ -37,16 +37,16 @@ def get_primitive_dtype(value): # Data types supported as model inputs and outputs. -ORTModelInputOutputType = Union[ - None, - str, - int, - bool, - float, - torch.Tensor, - Sequence["ORTModelInputOutputType"], - Mapping[str, "ORTModelInputOutputType"], -] +ORTModelInputOutputType = ( + str + | int + | bool + | float + | torch.Tensor + | Sequence["ORTModelInputOutputType"] + | Mapping[str, "ORTModelInputOutputType"] + | None +) class _TensorStub: @@ -57,16 +57,16 @@ class _TensorStub: def __init__( self, tensor_idx: int, - name: Optional[str] = None, - dtype: Optional[str] = None, + name: str | None = None, + dtype: str | None = None, shape=None, - shape_dims: Optional[int] = None, + shape_dims: int | None = None, ): self.tensor_idx = tensor_idx - self.name: Optional[str] = name - self.dtype: Optional[str] = dtype + self.name: str | None = name + self.dtype: str | None = dtype self.shape = shape - self.shape_dims: Optional[int] = shape_dims # r.g. rank. + self.shape_dims: int | None = shape_dims # r.g. rank. def __repr__(self) -> str: result = "_TensorStub(" @@ -108,13 +108,9 @@ def __eq__(self, other): # Data schema used to represent model's input or output. -ORTModelInputOutputSchemaType = Union[ - None, - str, - _TensorStub, - Sequence["ORTModelInputOutputSchemaType"], - Mapping[str, "ORTModelInputOutputSchemaType"], -] +ORTModelInputOutputSchemaType = ( + str | _TensorStub | Sequence["ORTModelInputOutputSchemaType"] | Mapping[str, "ORTModelInputOutputSchemaType"] | None +) def _warn_of_constant_inputs(data): @@ -126,8 +122,8 @@ def _warn_of_constant_inputs(data): @nvtx_function_decorator def extract_data_and_schema( - data: ORTModelInputOutputType, constant_as_tensor=False, device: Optional[torch.device] = None -) -> Tuple[List[torch.Tensor], ORTModelInputOutputSchemaType]: + data: ORTModelInputOutputType, constant_as_tensor=False, device: torch.device | None = None +) -> tuple[list[torch.Tensor], ORTModelInputOutputSchemaType]: """Extract the data schema by replacing every torch.Tensor value with _TensorStub, and return all tensors in a list. @@ -235,7 +231,7 @@ def _flatten_from_data(data: ORTModelInputOutputType, prefix_name: str = ""): @nvtx_function_decorator def unflatten_data_using_schema( - data: List[torch.Tensor], schema: ORTModelInputOutputSchemaType + data: list[torch.Tensor], schema: ORTModelInputOutputSchemaType ) -> ORTModelInputOutputType: """Follows the schema to generate an output that is expected by the user. @@ -280,7 +276,7 @@ def unflatten_data_using_schema( """ - def _replace_stub_with_tensor_value(data_schema: ORTModelInputOutputSchemaType, data: List[torch.Tensor]): + def _replace_stub_with_tensor_value(data_schema: ORTModelInputOutputSchemaType, data: list[torch.Tensor]): # Recursively traverse across user_output and replace all _TensorStub # with torch.Tensor values from outputs following output_idx diff --git a/orttraining/orttraining/python/training/utils/torch_type_map.py b/orttraining/orttraining/python/training/utils/torch_type_map.py index 2b429f3fd4f3a..49c3b32fc5037 100644 --- a/orttraining/orttraining/python/training/utils/torch_type_map.py +++ b/orttraining/orttraining/python/training/utils/torch_type_map.py @@ -4,8 +4,6 @@ # -------------------------------------------------------------------------- -from typing import Union - import torch # Mapping from pytorch scalar type to onnx scalar type. @@ -36,7 +34,7 @@ _ONNX_TO_DTYPE = {onnx_dtype: torch_dtype for torch_dtype, onnx_dtype in _DTYPE_TO_ONNX.items()} -def pytorch_type_to_onnx_dtype(dtype_or_scalar_type: Union[torch.dtype, str]) -> torch.onnx.TensorProtoDataType: +def pytorch_type_to_onnx_dtype(dtype_or_scalar_type: torch.dtype | str) -> torch.onnx.TensorProtoDataType: """Converts a pytorch dtype or scalar type string to an onnx dtype. PyTorch type can be either a dtype or a scalar type string. """ diff --git a/orttraining/orttraining/test/python/_test_helpers.py b/orttraining/orttraining/test/python/_test_helpers.py index 3d75b3f98862e..1dd304549869d 100644 --- a/orttraining/orttraining/test/python/_test_helpers.py +++ b/orttraining/orttraining/test/python/_test_helpers.py @@ -95,7 +95,7 @@ def assert_gradients_match_and_reset_gradient( pt_named_params = list(pt_model.named_parameters()) assert len(ort_named_params) == len(pt_named_params) - for ort_named_param, pt_named_param in zip(ort_named_params, pt_named_params): + for ort_named_param, pt_named_param in zip(ort_named_params, pt_named_params, strict=False): ort_name, ort_param = ort_named_param pt_name, pt_param = pt_named_param @@ -180,7 +180,7 @@ def run_with_ort_on_device(device, model, input_list, label_input, is_eval_mode= def compare_tensor_list(val_list_a, val_list_b): - for val_a, val_b in zip(val_list_a, val_list_b): + for val_a, val_b in zip(val_list_a, val_list_b, strict=False): assert_values_are_close(val_a, val_b, atol=1e-7, rtol=1e-6) diff --git a/orttraining/orttraining/test/python/orttraining_test_dort.py b/orttraining/orttraining/test/python/orttraining_test_dort.py index bd36ebf545be6..759af0854145f 100644 --- a/orttraining/orttraining/test/python/orttraining_test_dort.py +++ b/orttraining/orttraining/test/python/orttraining_test_dort.py @@ -162,7 +162,7 @@ def run(fun, seed: torch.Tensor): # ORT result. tensors = run(optimized_elementwise_model, seed) - for tensor, baseline_tensor in zip(tensors, baseline_tensors): + for tensor, baseline_tensor in zip(tensors, baseline_tensors, strict=False): torch.testing.assert_close(tensor, baseline_tensor) assert len(cached.keys()) == 2, ( @@ -182,7 +182,7 @@ def run(fun, seed: torch.Tensor): # ORT result. tensors = run(optimized_elementwise_model, seed) - for tensor, baseline_tensor in zip(tensors, baseline_tensors): + for tensor, baseline_tensor in zip(tensors, baseline_tensors, strict=False): torch.testing.assert_close(tensor, baseline_tensor) # 4 GraphModule's respectively for @@ -369,7 +369,7 @@ def run(model, tensor_x, tensor_y): print(f"MNIST loss: {loss} (pytorch), {loss_new} (ort).") torch.testing.assert_close(loss, loss_new, rtol=1e-2, atol=1e-5) - for grad, grad_new in zip(grads, grads_new): + for grad, grad_new in zip(grads, grads_new, strict=False): torch.testing.assert_close(grad, grad_new) # Run 5 times because ORT runs have side effects and we want to make sure diff --git a/orttraining/orttraining/test/python/orttraining_test_experimental_gradient_graph.py b/orttraining/orttraining/test/python/orttraining_test_experimental_gradient_graph.py index dd26448f0c596..07a9ab3a1d1cf 100644 --- a/orttraining/orttraining/test/python/orttraining_test_experimental_gradient_graph.py +++ b/orttraining/orttraining/test/python/orttraining_test_experimental_gradient_graph.py @@ -92,7 +92,7 @@ def test_save(self): ort_outs = ort_session.run(None, ort_inputs) onnx_output_names = [node.name for node in onnx_model.graph.output] - onnx_name_to_output = dict(zip(onnx_output_names, ort_outs)) + onnx_name_to_output = dict(zip(onnx_output_names, ort_outs, strict=False)) ort_output = onnx_name_to_output["output"] np.testing.assert_allclose(to_numpy(torch_out), ort_output, rtol=1e-03, atol=1e-05) diff --git a/orttraining/orttraining/test/python/orttraining_test_gru.py b/orttraining/orttraining/test/python/orttraining_test_gru.py index fcb7e13b1694f..0693b2ada447b 100644 --- a/orttraining/orttraining/test/python/orttraining_test_gru.py +++ b/orttraining/orttraining/test/python/orttraining_test_gru.py @@ -666,7 +666,7 @@ def test_gru_forward(sequence_length, batch_size, input_size, hidden_size, linea outs_ort = gru.forward_ort(inputs, weights, recurrence_weights, bias, initial_hidden_state) outs_np = gru.forward_np(inputs, weights, recurrence_weights, bias, initial_hidden_state) - for ort_out, np_out in zip(outs_ort, outs_np): + for ort_out, np_out in zip(outs_ort, outs_np, strict=False): assert np.allclose(ort_out, np_out, rtol=1e-03, atol=1e-05) @@ -716,5 +716,5 @@ def test_gru_backward(sequence_length, batch_size, input_size, hidden_size, line grad_final_hidden_state, ) - for ort_out, np_out in zip(outs_ort, outs_np): + for ort_out, np_out in zip(outs_ort, outs_np, strict=False): assert np.allclose(ort_out, np_out, rtol=1e-01, atol=1e-03) diff --git a/orttraining/orttraining/test/python/orttraining_test_hierarchical_ortmodule.py b/orttraining/orttraining/test/python/orttraining_test_hierarchical_ortmodule.py index 655c9def2c66c..ff1c4dc8aad13 100644 --- a/orttraining/orttraining/test/python/orttraining_test_hierarchical_ortmodule.py +++ b/orttraining/orttraining/test/python/orttraining_test_hierarchical_ortmodule.py @@ -200,7 +200,7 @@ def call_backward(y): def call_allclose(y, y_ref): assert type(y) is type(y_ref) if isinstance(y, Iterable): - for ele, ele_ref in zip(y, y_ref): + for ele, ele_ref in zip(y, y_ref, strict=False): torch.allclose(ele, ele_ref) else: torch.allclose(y, y_ref) diff --git a/orttraining/orttraining/test/python/orttraining_test_lort.py b/orttraining/orttraining/test/python/orttraining_test_lort.py index ccd06e1a3ab62..3aca181edcfc2 100644 --- a/orttraining/orttraining/test/python/orttraining_test_lort.py +++ b/orttraining/orttraining/test/python/orttraining_test_lort.py @@ -101,7 +101,7 @@ def run(model, device, x, y): print(f"MNIST loss: {loss} (pytorch), {loss_new} (ort).") torch.testing.assert_close(loss.to("lazy"), loss_new, rtol=1e-2, atol=1e-5) - for g, g_new in zip(grads, grads_new): + for g, g_new in zip(grads, grads_new, strict=False): torch.testing.assert_close(g.to("lazy"), g_new) for _ in range(5): diff --git a/orttraining/orttraining/test/python/orttraining_test_lstm.py b/orttraining/orttraining/test/python/orttraining_test_lstm.py index 1d75f12801fba..57fb6c4d1985b 100644 --- a/orttraining/orttraining/test/python/orttraining_test_lstm.py +++ b/orttraining/orttraining/test/python/orttraining_test_lstm.py @@ -867,7 +867,7 @@ def test_lstm_forward(sequence_length, batch_size, input_size, hidden_size): inputs, weights, recurrence_weights, bias, initial_hidden_state, initial_cell_state, peephole_weights ) - for ort_out, np_out in zip(outs_ort, outs_np): + for ort_out, np_out in zip(outs_ort, outs_np, strict=False): assert np.allclose(ort_out, np_out, rtol=1e-03, atol=1e-05) @@ -933,5 +933,5 @@ def test_lstm_backward(sequence_length, batch_size, input_size, hidden_size): grad_final_cell_state, ) - for ort_out, np_out in zip(outs_ort, outs_np): + for ort_out, np_out in zip(outs_ort, outs_np, strict=False): assert np.allclose(ort_out, np_out, rtol=1e-03, atol=1e-05) diff --git a/orttraining/orttraining/test/python/orttraining_test_ort_apis_onnxblock.py b/orttraining/orttraining/test/python/orttraining_test_ort_apis_onnxblock.py index 275d53daec889..d8f2ae2a5bcee 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ort_apis_onnxblock.py +++ b/orttraining/orttraining/test/python/orttraining_test_ort_apis_onnxblock.py @@ -605,7 +605,7 @@ def test_retrieve_parameters(): # Then assert not non_trainable_params - for ort_param, (pt_param_name, pt_param) in zip(trainable_params, pt_model.named_parameters()): + for ort_param, (pt_param_name, pt_param) in zip(trainable_params, pt_model.named_parameters(), strict=False): assert ort_param.name == pt_param_name assert np.allclose( np.frombuffer(ort_param.raw_data, dtype=np.float32).reshape(pt_param.shape), @@ -853,7 +853,7 @@ def mse_loss(prediction, target): ort_outs = ort_session.run([ort_output_names], ort_inputs) # assert all the gradients are close - for ort_grad, pt_param in zip(ort_outs[0], pt_model.parameters()): + for ort_grad, pt_param in zip(ort_outs[0], pt_model.parameters(), strict=False): assert np.allclose(ort_grad, _to_numpy(pt_param.grad)) diff --git a/orttraining/orttraining/test/python/orttraining_test_ort_pipeline_module.py b/orttraining/orttraining/test/python/orttraining_test_ort_pipeline_module.py index d59e32cde33dd..8047e4217c6f9 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ort_pipeline_module.py +++ b/orttraining/orttraining/test/python/orttraining_test_ort_pipeline_module.py @@ -1,5 +1,4 @@ import argparse -from typing import Dict, Tuple import deepspeed import torch @@ -39,14 +38,14 @@ def __init__(self, x: torch.Tensor, y: torch.Tensor): def __len__(self) -> int: return self.x.size(0) - def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]: + def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]: return self.x[idx], self.y[idx] class SimpleNetPipeInput(nn.Module): """First stage of the pipeline, responsible for initial processing.""" - def __init__(self, config: Dict[str, int]): + def __init__(self, config: dict[str, int]): super().__init__() self.linear = nn.Linear(config["input_size"], config["hidden_size"]) self.activation = nn.ReLU() @@ -60,7 +59,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class SimpleNetPipeBlock(nn.Module): """Intermediate stage of the pipeline, can be duplicated to deepen the network.""" - def __init__(self, config: Dict[str, int]): + def __init__(self, config: dict[str, int]): super().__init__() self.linear = nn.Linear(config["hidden_size"], config["hidden_size"]) self.activation = nn.ReLU() @@ -74,7 +73,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class SimpleNetPipeOutput(nn.Module): """Final stage of the pipeline, producing the output.""" - def __init__(self, config: Dict[str, int]): + def __init__(self, config: dict[str, int]): super().__init__() self.linear = nn.Linear(config["hidden_size"], config["output_size"]) @@ -83,7 +82,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return x -def build_model(config: Dict[str, int], n: int, layer_spec: bool) -> nn.Module: +def build_model(config: dict[str, int], n: int, layer_spec: bool) -> nn.Module: """Constructs and returns the model either using LayerSpec or nn.Sequential.""" if layer_spec: print("Wrapping layers with LayerSpec") diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py index 912af9bc88755..7eaa7d1d9cb5d 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py @@ -4166,7 +4166,7 @@ def forward( out_ort = ort_model(*y) assert len(out_pt) == len(out_ort) - for x, y in zip(out_pt, out_ort): + for x, y in zip(out_pt, out_ort, strict=False): _test_helpers.assert_values_are_close(x, y) @@ -4257,7 +4257,7 @@ def test_hf_save_pretrained(): ).to(device) model2 = ORTModule(model2) - for p1, p2 in zip(model1.parameters(), model2.parameters()): + for p1, p2 in zip(model1.parameters(), model2.parameters(), strict=False): assert p1.data.ne(p2.data).sum() == 0 @@ -5123,7 +5123,7 @@ def run_optim_step(optimizer): pt_loss = run_step(pt_model, x1) ort_loss = run_step(ort_model, x2) - for pt_param, ort_param in zip(pt_model.parameters(), ort_model.parameters()): + for pt_param, ort_param in zip(pt_model.parameters(), ort_model.parameters(), strict=False): ort_param.grad = copy.deepcopy(pt_param.grad) _test_helpers.assert_values_are_close(pt_loss, ort_loss) @@ -5133,7 +5133,7 @@ def run_optim_step(optimizer): run_optim_step(transformers_adamw_optimizer) run_optim_step(ort_fused_adam_optimizer) - for pt_param, ort_param in zip(pt_model.parameters(), ort_model.parameters()): + for pt_param, ort_param in zip(pt_model.parameters(), ort_model.parameters(), strict=False): _test_helpers.assert_values_are_close(pt_param, ort_param, atol=1e-4, rtol=1e-5) @@ -5173,7 +5173,7 @@ def run_optim_step(optimizer): pt_loss = run_step(pt_model, x1) ort_loss = run_step(ort_model, x2) - for pt_param, ort_param in zip(pt_model.parameters(), ort_model.parameters()): + for pt_param, ort_param in zip(pt_model.parameters(), ort_model.parameters(), strict=False): ort_param.grad = copy.deepcopy(pt_param.grad) _test_helpers.assert_values_are_close(pt_loss, ort_loss, atol=1e-4, rtol=1e-5) @@ -5185,7 +5185,7 @@ def run_optim_step(optimizer): run_optim_step(adamw_optimizer) run_optim_step(ort_fused_adam_optimizer) - for pt_param, ort_param in zip(pt_model.parameters(), ort_model.parameters()): + for pt_param, ort_param in zip(pt_model.parameters(), ort_model.parameters(), strict=False): _test_helpers.assert_values_are_close(pt_param, ort_param, atol=1e-4, rtol=1e-5) @@ -5506,7 +5506,7 @@ def random_state_equal(a, b): assert type(a) is type(b) if isinstance(a, tuple): assert len(a) == len(b) - return all([random_state_equal(a_i, b_i) for a_i, b_i in zip(a, b)]) + return all([random_state_equal(a_i, b_i) for a_i, b_i in zip(a, b, strict=False)]) if isinstance(a, np.ndarray): return np.array_equal(a, b) if isinstance(a, torch.Tensor): @@ -6170,7 +6170,7 @@ def generate_inputs(batch_size, max_seq_length, vocab_size): run_optim_step(pt_optimizer) run_optim_step(ort_optimizer) - for pt_param, ort_param in zip(pt_model.parameters(), ort_model.parameters()): + for pt_param, ort_param in zip(pt_model.parameters(), ort_model.parameters(), strict=False): _test_helpers.assert_values_are_close(pt_param.grad, ort_param.grad, atol=1e-4, rtol=1e-5) if os.getenv("ORTMODULE_ROCM_TEST", "0") == "1": @@ -6394,7 +6394,7 @@ def run_step(model, x): pt_grads = run_step(pt_model, pt_x) ort_grads = run_step(ort_model, ort_x) - for pt_grad, ort_grad in zip(pt_grads, ort_grads): + for pt_grad, ort_grad in zip(pt_grads, ort_grads, strict=False): if use_fp16: assert torch.allclose(pt_grad, ort_grad, atol=1e-3, rtol=1e-3) else: @@ -6443,7 +6443,7 @@ def run_step(model, x): pt_grads = run_step(pt_model, pt_x) ort_grads = run_step(ort_model, ort_x) - for pt_grad, ort_grad in zip(pt_grads, ort_grads): + for pt_grad, ort_grad in zip(pt_grads, ort_grads, strict=False): assert torch.allclose(pt_grad, ort_grad) if conv_algo_search is not None: @@ -6489,7 +6489,7 @@ def run_step(model, x): pt_grads = run_step(pt_model, pt_x) ort_grads = run_step(ort_model, ort_x) - for pt_grad, ort_grad in zip(pt_grads, ort_grads): + for pt_grad, ort_grad in zip(pt_grads, ort_grads, strict=False): assert torch.allclose(pt_grad, ort_grad, atol=1e-2, rtol=1e-2) if conv_algo_search is not None: @@ -6917,7 +6917,7 @@ def generate_inputs(batch_size, max_seq_length, vocab_size): ort_model2 = ORTModule(copy.deepcopy(pt_model), DebugOptions(save_onnx=True, onnx_prefix="recompute")) ort_prediction2 = run_step(ort_model2, ort_input, ort_mask, ort_target) - for ort_param1, ort_param2 in zip(ort_model1.parameters(), ort_model2.parameters()): + for ort_param1, ort_param2 in zip(ort_model1.parameters(), ort_model2.parameters(), strict=False): _test_helpers.assert_values_are_close(ort_param1.grad, ort_param2.grad, atol=1e-4, rtol=1e-5) if os.getenv("ORTMODULE_ROCM_TEST", "0") == "1": diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_autograd.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_autograd.py index 5764a6a81e5db..2e1c90bcac5cd 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_autograd.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_autograd.py @@ -7,7 +7,6 @@ import copy import os -from typing import Tuple import onnx import pytest @@ -264,13 +263,13 @@ def forward( ctx, input, alpha: float, - beta: Tuple[float, float], + beta: tuple[float, float], gamma: float, delta: bool, - epsilon: Tuple[bool, bool], + epsilon: tuple[bool, bool], zeta: int, - eta: Tuple[int, int], - theta: Tuple[float, float], + eta: tuple[int, int], + theta: tuple[float, float], ): ctx.save_for_backward(input) ctx.alpha = alpha @@ -296,7 +295,7 @@ def backward(ctx, grad_output): assert alpha == alpha_value assert isinstance(alpha, float) - assert all(a == b for a, b in zip(beta, beta_value)) + assert all(a == b for a, b in zip(beta, beta_value, strict=False)) assert all(isinstance(x, float) for x in beta) assert gamma == gamma_value @@ -305,16 +304,16 @@ def backward(ctx, grad_output): assert ctx.delta == delta_value assert isinstance(ctx.delta, bool) - assert all(a == b for a, b in zip(ctx.epsilon, epsilon_value)) + assert all(a == b for a, b in zip(ctx.epsilon, epsilon_value, strict=False)) assert all(isinstance(x, bool) for x in ctx.epsilon) assert ctx.zeta == zeta_value assert isinstance(ctx.zeta, int) - assert all(a == b for a, b in zip(ctx.eta, eta_value)) + assert all(a == b for a, b in zip(ctx.eta, eta_value, strict=False)) assert all(isinstance(x, int) for x in ctx.eta) - assert all(a == b for a, b in zip(ctx.theta, theta_value)) + assert all(a == b for a, b in zip(ctx.theta, theta_value, strict=False)) assert all(isinstance(x, float) for x in ctx.theta) return alpha * beta[0] * beta[1] * gamma * grad_input, None, None, None, None, None, None, None, None @@ -1651,7 +1650,7 @@ def _compare_shape(shape1, shape2): if len(shape1.dim) != len(shape2.dim): return False - for dim1, dim2 in zip(shape1.dim, shape2.dim): + for dim1, dim2 in zip(shape1.dim, shape2.dim, strict=False): if dim1.HasField("dim_value") and dim1.HasField("dim_value") and dim1.dim_value == dim2.dim_value: continue diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_onnx_ops.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_onnx_ops.py index 07d581b576c45..54e414b36c2ba 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_onnx_ops.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_onnx_ops.py @@ -30,7 +30,7 @@ def assert_gradients_match_and_reset_gradient( pt_named_params = list(pt_model.named_parameters()) self.assertEqual(len(ort_named_params), len(pt_named_params)) - for ort_named_param, pt_named_param in zip(ort_named_params, pt_named_params): + for ort_named_param, pt_named_param in zip(ort_named_params, pt_named_params, strict=False): ort_name, ort_param = ort_named_param pt_name, pt_param = pt_named_param diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_triton.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_triton.py index 0c381d70ca4c1..85b7180d97ff3 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_triton.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_triton.py @@ -206,7 +206,7 @@ def _run_op_test(op_type, onnx_dtype, create_model_func, gen_inputs_func, **kwar if isinstance(pt_outputs, tuple): assert isinstance(ort_outputs, tuple) assert len(pt_outputs) == len(ort_outputs) - for pt_output, ort_output in zip(pt_outputs, ort_outputs): + for pt_output, ort_output in zip(pt_outputs, ort_outputs, strict=False): _test_helpers.assert_values_are_close(pt_output, _from_dlpack(ort_output), rtol=rtol, atol=atol) else: _test_helpers.assert_values_are_close(pt_outputs, _from_dlpack(ort_outputs), rtol=rtol, atol=atol) @@ -489,7 +489,7 @@ def test_dropout_op(onnx_dtype, input_shape_and_ratio): def _check_output(x, y, mask, ratio): all_count = 0 masked_count = 0 - for x_value, y_value, mask_value in zip(x, y, mask): + for x_value, y_value, mask_value in zip(x, y, mask, strict=False): if mask_value: assert abs(y_value - x_value / (1.0 - ratio)) < 0.05 else: diff --git a/orttraining/orttraining/test/python/orttraining_test_ortvalue.py b/orttraining/orttraining/test/python/orttraining_test_ortvalue.py index 317efa0061865..327be44ed88c2 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortvalue.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortvalue.py @@ -104,7 +104,7 @@ def testOrtValueVector_float32(self): vect.push_back(ortvalue._ortvalue) self.assertEqual(len(vect.bool_tensor_indices()), 0) self.assertEqual(len(vect), 2) - for i, (ov, ar) in enumerate(zip(vect, narrays)): + for i, (ov, ar) in enumerate(zip(vect, narrays, strict=False)): ovar = ov.numpy() assert_almost_equal(ar, ovar) self.assertEqual(ov.element_type(), vect.element_type_at(i)) @@ -120,7 +120,7 @@ def testOrtValueVector_bool(self): vect.push_back(ortvalue._ortvalue) self.assertEqual(vect.bool_tensor_indices(), [0, 1]) self.assertEqual(len(vect), 2) - for ov, ar in zip(vect, narrays): + for ov, ar in zip(vect, narrays, strict=False): ovar = ov.numpy() assert_almost_equal(ar, ovar) @@ -152,7 +152,7 @@ def OrtValueVectorDlPackOrtValue(self, my_to_tensor, tensor_type, device, dtype= self.assertEqual(cf, cf2) # it should be [3, 3] ptr2 = [] - for av1, v2 in zip(narrays, converted_values): + for av1, v2 in zip(narrays, converted_values, strict=False): ptr2.append(v2.data_ptr()) if hasattr(v2, "cpu"): av2 = v2.cpu().numpy() diff --git a/orttraining/orttraining/test/python/orttraining_test_sampler.py b/orttraining/orttraining/test/python/orttraining_test_sampler.py index 68f9ac5052134..0a6b54d972a46 100644 --- a/orttraining/orttraining/test/python/orttraining_test_sampler.py +++ b/orttraining/orttraining/test/python/orttraining_test_sampler.py @@ -54,7 +54,7 @@ def test_load_balancing_data_sampler_shuffles_and_balances_load(): random.shuffle(complexities) samples = [torch.FloatTensor([val]) for val in range(100)] - samples_and_complexities = list(zip(samples, complexities)) + samples_and_complexities = list(zip(samples, complexities, strict=False)) dataset = MyDataset(samples_and_complexities) def complexity_fn(sample): @@ -67,7 +67,7 @@ def complexity_fn(sample): dataset, complexity_fn=complexity_fn, world_size=2, rank=1, shuffle=True ) - for index0, index1 in zip(data_sampler0, data_sampler1): + for index0, index1 in zip(data_sampler0, data_sampler1, strict=False): assert samples_and_complexities[index0][1] == samples_and_complexities[index1][1] @@ -90,7 +90,7 @@ def complexity_fn(sample): dataset, complexity_fn=complexity_fn, world_size=1, rank=0, shuffle=False, group_size=8 ) - for index, sorted_sample in zip(data_sampler, samples_and_complexities_sorted): + for index, sorted_sample in zip(data_sampler, samples_and_complexities_sorted, strict=False): assert samples_and_complexities[index][1] == sorted_sample[1] @@ -127,7 +127,9 @@ def complexity_fn(sample): dataset, complexity_fn=complexity_fn, world_size=1, rank=0, shuffle=True, group_size=8 ) - for index, sorted_and_shuffled_sample in zip(data_sampler, samples_and_complexities_sorted_and_shuffled): + for index, sorted_and_shuffled_sample in zip( + data_sampler, samples_and_complexities_sorted_and_shuffled, strict=False + ): assert samples_and_complexities[index][1] == sorted_and_shuffled_sample[1] diff --git a/orttraining/orttraining/test/python/orttraining_test_utilities.py b/orttraining/orttraining/test/python/orttraining_test_utilities.py index faa04f327be7f..c3fc9c2d2577a 100644 --- a/orttraining/orttraining/test/python/orttraining_test_utilities.py +++ b/orttraining/orttraining/test/python/orttraining_test_utilities.py @@ -256,7 +256,12 @@ def _recursive_compare(real, expected): if flag == 0: out, schema = extract_data_and_schema(raw_data) - assert all([torch.allclose(o, d) if isinstance(o, torch.Tensor) else o == d for o, d in zip(out, flatten_data)]) + assert all( + [ + torch.allclose(o, d) if isinstance(o, torch.Tensor) else o == d + for o, d in zip(out, flatten_data, strict=False) + ] + ) if not isinstance(raw_data, torch.Tensor): assert type(schema) is type(raw_data) @@ -276,7 +281,7 @@ def _recursive_compare(real, expected): assert all( [ torch.allclose(o, d) if isinstance(o, torch.Tensor) else o == d - for o, d in zip(out, flatten_data_constant_as_tensor) + for o, d in zip(out, flatten_data_constant_as_tensor, strict=False) ] ) diff --git a/orttraining/tools/ci_test/compare_results.py b/orttraining/tools/ci_test/compare_results.py index 0ab0a1246a421..2d4a3d31dec41 100644 --- a/orttraining/tools/ci_test/compare_results.py +++ b/orttraining/tools/ci_test/compare_results.py @@ -43,7 +43,7 @@ def _compare_results(expected_results, actual_results, field_comparisons): return False mismatch_detected = False - for row_idx, (expected_row, actual_row) in enumerate(zip(expected_results, actual_results)): + for row_idx, (expected_row, actual_row) in enumerate(zip(expected_results, actual_results, strict=False)): for field_name, comparison in field_comparisons.items(): actual, expected = actual_row[field_name], expected_row[field_name] if not comparison.fn(actual, expected): diff --git a/orttraining/tools/scripts/nv_run_pretraining.py b/orttraining/tools/scripts/nv_run_pretraining.py index 8f399263e1e65..565f5af84d4fa 100644 --- a/orttraining/tools/scripts/nv_run_pretraining.py +++ b/orttraining/tools/scripts/nv_run_pretraining.py @@ -336,7 +336,7 @@ def prepare_model_and_optimizer(args, device): optimizer._lazy_init_maybe_master_weights() optimizer._amp_stash.lazy_init_called = True optimizer.load_state_dict(checkpoint["optimizer"]) - for param, saved_param in zip(amp.master_params(optimizer), checkpoint["master params"]): + for param, saved_param in zip(amp.master_params(optimizer), checkpoint["master params"], strict=False): param.data.copy_(saved_param.data) if args.local_rank != -1: diff --git a/tools/ci_build/build.py b/tools/ci_build/build.py index 87180a242e370..9d8e44b9cab47 100644 --- a/tools/ci_build/build.py +++ b/tools/ci_build/build.py @@ -1387,7 +1387,7 @@ def generate_build_tree( if not all(needed_args): raise BuildError( "iOS/MacOS framework build on MacOS canceled due to missing arguments: " - + ", ".join(val for val, cond in zip(arg_names, needed_args) if not cond) + + ", ".join(val for val, cond in zip(arg_names, needed_args, strict=False) if not cond) ) # note: this value is mainly used in framework_info.json file to specify the build osx type platform_name = "macabi" if args.macos == "Catalyst" else args.apple_sysroot diff --git a/tools/ci_build/github/apple/package_assembly_utils.py b/tools/ci_build/github/apple/package_assembly_utils.py index c6822466d73d0..829bca8c743df 100644 --- a/tools/ci_build/github/apple/package_assembly_utils.py +++ b/tools/ci_build/github/apple/package_assembly_utils.py @@ -7,7 +7,6 @@ import pathlib import re import shutil -from typing import Dict, List _script_dir = pathlib.Path(__file__).parent.resolve(strict=True) repo_root = _script_dir.parents[3] @@ -30,7 +29,7 @@ def all_variant_names(cls): def gen_file_from_template( - template_file: pathlib.Path, output_file: pathlib.Path, variable_substitutions: Dict[str, str], strict: bool = True + template_file: pathlib.Path, output_file: pathlib.Path, variable_substitutions: dict[str, str], strict: bool = True ): """ Generates a file from a template file. @@ -69,7 +68,7 @@ def replace_template_variable(match): output.write(content) -def filter_files(all_file_patterns: List[str], excluded_file_patterns: List[str]): +def filter_files(all_file_patterns: list[str], excluded_file_patterns: list[str]): """ Filters file paths based on inclusion and exclusion patterns @@ -90,7 +89,7 @@ def filter_files(all_file_patterns: List[str], excluded_file_patterns: List[str] return list(set(all_files) - set(exclude_files)) -def copy_repo_relative_to_dir(patterns: List[str], dest_dir: pathlib.Path): +def copy_repo_relative_to_dir(patterns: list[str], dest_dir: pathlib.Path): """ Copies file paths relative to the repo root to a directory. The given paths or path patterns are relative to the repo root, and the diff --git a/tools/ci_build/op_registration_utils.py b/tools/ci_build/op_registration_utils.py index 811ce424eae10..0911a16d226f8 100644 --- a/tools/ci_build/op_registration_utils.py +++ b/tools/ci_build/op_registration_utils.py @@ -8,7 +8,6 @@ import os import pathlib import sys -import typing from logger import get_logger @@ -88,12 +87,12 @@ class RegistrationProcessor: def process_registration( self, - lines: typing.List[str], + lines: list[str], domain: str, operator: str, start_version: int, - end_version: typing.Optional[int] = None, - type: typing.Optional[str] = None, + end_version: int | None = None, + type: str | None = None, ): """ Process lines that contain a kernel registration. @@ -119,7 +118,7 @@ def ok(self): return False # return False as the derived class must override to report the real status -def _process_lines(lines: typing.List[str], offset: int, registration_processor: RegistrationProcessor): +def _process_lines(lines: list[str], offset: int, registration_processor: RegistrationProcessor): """ Process one or more lines that contain a kernel registration. Merge lines if split over multiple, and call registration_processor.process_registration with the original lines @@ -236,9 +235,7 @@ def _process_lines(lines: typing.List[str], offset: int, registration_processor: return offset + 1 -def process_kernel_registration_file( - filename: typing.Union[str, pathlib.Path], registration_processor: RegistrationProcessor -): +def process_kernel_registration_file(filename: str | pathlib.Path, registration_processor: RegistrationProcessor): """ Process a kernel registration file using registration_processor. :param filename: Path to file containing kernel registrations. diff --git a/tools/ci_build/op_registration_validator.py b/tools/ci_build/op_registration_validator.py index d92050a31f967..b64e4323e8541 100644 --- a/tools/ci_build/op_registration_validator.py +++ b/tools/ci_build/op_registration_validator.py @@ -37,8 +37,8 @@ class RegistrationInfo: domain: str operator: str start_version: int - end_version: typing.Optional[int] - lines: typing.List[str] + end_version: int | None + lines: list[str] def domain_and_op_str(self): return f"{self.domain}:{self.operator}" @@ -50,16 +50,16 @@ def _log_registration_error(r: RegistrationInfo, message: str): class RegistrationValidator(op_registration_utils.RegistrationProcessor): def __init__(self): - self.all_registrations: typing.List[RegistrationInfo] = [] + self.all_registrations: list[RegistrationInfo] = [] def process_registration( self, - lines: typing.List[str], + lines: list[str], domain: str, operator: str, start_version: int, - end_version: typing.Optional[int] = None, - type: typing.Optional[str] = None, + end_version: int | None = None, + type: str | None = None, ): self.all_registrations.append( RegistrationInfo( @@ -114,7 +114,7 @@ def _validate_registrations_for_domain_and_op(self, registrations: typing.Iterat return num_invalid_registrations - def _validate_registration(self, r: RegistrationInfo, next_r: typing.Optional[RegistrationInfo]) -> bool: + def _validate_registration(self, r: RegistrationInfo, next_r: RegistrationInfo | None) -> bool: """ Validates a registration, `r`, with the next one in sorted order for a single domain and op, `next_r`, and returns whether it is valid. diff --git a/tools/ci_build/reduce_op_kernels.py b/tools/ci_build/reduce_op_kernels.py index df6bbf7a4058e..ac26abc5a9d55 100755 --- a/tools/ci_build/reduce_op_kernels.py +++ b/tools/ci_build/reduce_op_kernels.py @@ -28,7 +28,7 @@ def _adapt_filters_for_extended_minimal_build( - base_required_ops: typing.Optional[dict], base_op_type_impl_filter: typing.Optional[OpTypeImplFilterInterface] + base_required_ops: dict | None, base_op_type_impl_filter: OpTypeImplFilterInterface | None ): """ Adapts the values returned by parse_config() for an extended minimal build or higher. @@ -77,7 +77,7 @@ class _AdaptedFilter(OpTypeImplFilterInterface): def __init__( self, filter_to_adapt: OpTypeImplFilterInterface, - required_domain_and_optypes: typing.Set[typing.Tuple[str, str]], + required_domain_and_optypes: set[tuple[str, str]], ): self.filter_to_adapt = filter_to_adapt self.required_domain_and_optypes = required_domain_and_optypes @@ -107,17 +107,15 @@ class _ExcludingRegistrationProcessor(op_registration_utils.RegistrationProcesso def __init__( self, - required_ops: typing.Optional[dict], - op_type_impl_filter: typing.Optional[OpTypeImplFilterInterface], + required_ops: dict | None, + op_type_impl_filter: OpTypeImplFilterInterface | None, output_file: io.TextIOWrapper, ): self._required_ops = required_ops self._op_type_impl_filter = op_type_impl_filter self._output_file = output_file - def _is_op_required( - self, domain: str, operator: str, start_version: int, end_version: typing.Optional[int] - ) -> bool: + def _is_op_required(self, domain: str, operator: str, start_version: int, end_version: int | None) -> bool: """See if an op is required.""" if self._required_ops is None: return True @@ -134,12 +132,12 @@ def _is_op_required( def process_registration( self, - lines: typing.List[str], + lines: list[str], constant_for_domain: str, operator: str, start_version: int, - end_version: typing.Optional[int] = None, - type: typing.Optional[str] = None, + end_version: int | None = None, + type: str | None = None, ): registration_identifier = "{}:{}({}){}".format( constant_for_domain, operator, start_version, f"<{type}>" if type else "" @@ -202,8 +200,8 @@ def _generate_provider_registrations( ort_root: Path, build_dir: Path, use_cuda: bool, - required_ops: typing.Optional[dict], - op_type_impl_filter: typing.Optional[OpTypeImplFilterInterface], + required_ops: dict | None, + op_type_impl_filter: OpTypeImplFilterInterface | None, ): """Generate provider registration files.""" kernel_registration_files = [ diff --git a/tools/python/find_optimizer_opset_version_updates_required.py b/tools/python/find_optimizer_opset_version_updates_required.py index b46f7e4a54d9c..3c7d94b8ba038 100644 --- a/tools/python/find_optimizer_opset_version_updates_required.py +++ b/tools/python/find_optimizer_opset_version_updates_required.py @@ -7,7 +7,6 @@ import logging import os import re -import typing logging.basicConfig(format="[%(levelname)s] - %(message)s", level=logging.DEBUG) log = logging.getLogger() @@ -30,7 +29,7 @@ def parse_args(): return args -def get_call_args_from_file(filename: str, function_or_declaration: str) -> typing.List[str]: +def get_call_args_from_file(filename: str, function_or_declaration: str) -> list[str]: """ Search a file for all function calls or declarations that match the provided name. Requires both the opening '(' and closing ')' to be on the same line. @@ -63,7 +62,7 @@ def get_call_args_from_file(filename: str, function_or_declaration: str) -> typi return results -def get_multiline_call_args_from_file(filename: str, function_or_declaration: str) -> typing.List[str]: +def get_multiline_call_args_from_file(filename: str, function_or_declaration: str) -> list[str]: """ Search a file for all function calls or declarations that match the provided name. Allows the opening '(' and closing ')' to be split across multiple lines. @@ -96,7 +95,7 @@ def get_multiline_call_args_from_file(filename: str, function_or_declaration: st return results -def _add_if_newer(domain: str, op: str, opset: int, op_to_opset: typing.Dict[str, int]): +def _add_if_newer(domain: str, op: str, opset: int, op_to_opset: dict[str, int]): key = domain + "." + op if key not in op_to_opset or op_to_opset[key] < opset: op_to_opset[key] = opset diff --git a/tools/python/gen_contrib_doc.py b/tools/python/gen_contrib_doc.py index ce6f0a1205fdc..54bc920ea3ddd 100644 --- a/tools/python/gen_contrib_doc.py +++ b/tools/python/gen_contrib_doc.py @@ -8,7 +8,7 @@ import pathlib import sys from collections import defaultdict -from typing import Any, Dict, List, Sequence, Set, Text, Tuple # noqa: F401 +from collections.abc import Sequence # noqa: F401 import numpy as np # type: ignore from onnx import AttributeProto, FunctionProto # noqa: F401 diff --git a/tools/python/onnx2tfevents.py b/tools/python/onnx2tfevents.py index 9dfde13090b07..909bc04817ff1 100644 --- a/tools/python/onnx2tfevents.py +++ b/tools/python/onnx2tfevents.py @@ -13,7 +13,7 @@ import inspect import itertools from abc import ABC, abstractmethod -from typing import Callable, List +from collections.abc import Callable import numpy as np import onnx @@ -203,7 +203,7 @@ def _add_sections(self, name: str) -> None: if len(sec) > 0: self.sections.add(sec) - def _get_sections(self, curr_name: str, sections: List[str]) -> None: + def _get_sections(self, curr_name: str, sections: list[str]) -> None: for section in self.sections: if curr_name.startswith(section) and (len(curr_name) == len(section) or curr_name[len(section)] == "."): sections.append(section) @@ -217,8 +217,7 @@ def _transform_name(self, name: str) -> str: if "/" in name: if name.startswith(f"/{self.original_module_name}/"): name = name[len(self.original_module_name) + 2 :] - if name.startswith("/"): - name = name[1:] + name = name.removeprefix("/") return name sections = [] diff --git a/tools/python/ort_test_dir_utils.py b/tools/python/ort_test_dir_utils.py index 3af407b2aeee6..59bb6670c8794 100644 --- a/tools/python/ort_test_dir_utils.py +++ b/tools/python/ort_test_dir_utils.py @@ -159,7 +159,7 @@ def save_data(prefix, name_data_map, model_info): sess = ort.InferenceSession(test_model_filename, so) outputs = sess.run(output_names, name_input_map) name_output_map = {} - for name, data in zip(output_names, outputs): + for name, data in zip(output_names, outputs, strict=False): name_output_map[name] = data save_data("output", name_output_map, model_outputs) diff --git a/tools/python/run_CIs_for_branch.py b/tools/python/run_CIs_for_branch.py index 975ea2b988d75..d1e23f28acdf4 100644 --- a/tools/python/run_CIs_for_branch.py +++ b/tools/python/run_CIs_for_branch.py @@ -7,7 +7,6 @@ import os import subprocess import sys -import typing from run_CIs_for_external_pr import get_pipeline_names from util.platform_helpers import is_windows @@ -78,7 +77,7 @@ def _parse_args(): return args -def _run_az_pipelines_command(command: typing.List[str]): +def _run_az_pipelines_command(command: list[str]): try: az = "az.cmd" if is_windows() else "az" az_output = subprocess.run([az, "pipelines", *command], capture_output=True, text=True, check=True) diff --git a/tools/python/run_CIs_for_external_pr.py b/tools/python/run_CIs_for_external_pr.py index 228c8016170d9..faaeb4e5f7127 100644 --- a/tools/python/run_CIs_for_external_pr.py +++ b/tools/python/run_CIs_for_external_pr.py @@ -7,7 +7,6 @@ import os import subprocess import sys -import typing def get_pipeline_names(): @@ -72,7 +71,7 @@ def _parse_args(): return args -def run_gh_pr_command(command: typing.List[str], check: bool = True): +def run_gh_pr_command(command: list[str], check: bool = True): try: return subprocess.run(["gh", "pr", *command], capture_output=True, text=True, check=check) except subprocess.CalledProcessError as cpe: diff --git a/tools/python/run_adb.py b/tools/python/run_adb.py index 7506a8699df05..1928966b56a26 100755 --- a/tools/python/run_adb.py +++ b/tools/python/run_adb.py @@ -5,13 +5,12 @@ import logging import os import sys -import typing from util import run from util.android import get_sdk_tool_paths -def run_adb(android_sdk_root: str, args: typing.List[str]): +def run_adb(android_sdk_root: str, args: list[str]): sdk_tool_paths = get_sdk_tool_paths(android_sdk_root) run(sdk_tool_paths.adb, *args) diff --git a/tools/python/sparsify_initializers.py b/tools/python/sparsify_initializers.py index 2c80b07cd0a12..9232210b5fab2 100644 --- a/tools/python/sparsify_initializers.py +++ b/tools/python/sparsify_initializers.py @@ -9,11 +9,10 @@ import argparse import logging import sys -from typing import List, Tuple # noqa: F401 import numpy as np import onnx -from onnx import ModelProto, SparseTensorProto, TensorProto, numpy_helper # noqa: F401 +from onnx import ModelProto, TensorProto, numpy_helper logger = logging.getLogger(__name__) diff --git a/tools/python/util/android/android.py b/tools/python/util/android/android.py index 24004d6be761d..13f2b6b8d4952 100644 --- a/tools/python/util/android/android.py +++ b/tools/python/util/android/android.py @@ -108,7 +108,7 @@ def _stop_process_with_pid(pid: int): def start_emulator( sdk_tool_paths: SdkToolPaths, avd_name: str, - extra_args: typing.Optional[typing.Sequence[str]] = None, + extra_args: typing.Sequence[str] | None = None, timeout_minutes: int = 20, ) -> subprocess.Popen: if check_emulator_running_using_avd_name(avd_name=avd_name): @@ -326,7 +326,7 @@ def stop_emulator_by_pid(emulator_pid: int, timeout_seconds: int = 120): _log.info("Emulator stopped successfully.") -def stop_emulator(emulator_proc_or_pid: typing.Union[subprocess.Popen, int], timeout_seconds: int = 120): +def stop_emulator(emulator_proc_or_pid: subprocess.Popen | int, timeout_seconds: int = 120): """ Stops the emulator process, checking its running status before and after stopping. :param emulator_proc_or_pid: The emulator process (subprocess.Popen) or PID (int). diff --git a/tools/python/util/file_utils.py b/tools/python/util/file_utils.py index 0373ac171144f..20cc15580c9a8 100644 --- a/tools/python/util/file_utils.py +++ b/tools/python/util/file_utils.py @@ -6,7 +6,7 @@ import typing -def path_match_suffix_ignore_case(path: typing.Union[pathlib.Path, str], suffix: str) -> bool: +def path_match_suffix_ignore_case(path: pathlib.Path | str, suffix: str) -> bool: """ Returns whether `path` ends in `suffix`, ignoring case. """ @@ -16,8 +16,8 @@ def path_match_suffix_ignore_case(path: typing.Union[pathlib.Path, str], suffix: def files_from_file_or_dir( - file_or_dir_path: typing.Union[pathlib.Path, str], predicate: typing.Callable[[pathlib.Path], bool] = lambda _: True -) -> typing.List[pathlib.Path]: + file_or_dir_path: pathlib.Path | str, predicate: typing.Callable[[pathlib.Path], bool] = lambda _: True +) -> list[pathlib.Path]: """ Gets the files in `file_or_dir_path` satisfying `predicate`. If `file_or_dir_path` is a file, the single file is considered. Otherwise, all files in the directory are diff --git a/tools/python/util/onnx_model_utils.py b/tools/python/util/onnx_model_utils.py index 1938a2411e11d..ffefae58471ad 100644 --- a/tools/python/util/onnx_model_utils.py +++ b/tools/python/util/onnx_model_utils.py @@ -3,7 +3,6 @@ import logging import pathlib -from typing import Optional import onnx from onnx import version_converter @@ -62,8 +61,8 @@ def get_opsets_imported(model: onnx.ModelProto): def update_onnx_opset( model_path: pathlib.Path, opset: int, - out_path: Optional[pathlib.Path] = None, - logger: Optional[logging.Logger] = None, + out_path: pathlib.Path | None = None, + logger: logging.Logger | None = None, ): """ Helper to update the opset of a model using onnx version_converter. Target opset must be greater than current opset. diff --git a/tools/python/util/ort_format_model/operator_type_usage_processors.py b/tools/python/util/ort_format_model/operator_type_usage_processors.py index 598549c42b60a..0e731fd421b0c 100644 --- a/tools/python/util/ort_format_model/operator_type_usage_processors.py +++ b/tools/python/util/ort_format_model/operator_type_usage_processors.py @@ -2,7 +2,6 @@ # Licensed under the MIT License. import json -import typing from abc import ABC, abstractmethod import ort_flatbuffers_py.fbs as fbs @@ -65,9 +64,7 @@ def __init__(self, domain: str, optype: str): def process_node(self, node: fbs.Node, value_name_to_typeinfo: dict): pass - def is_typed_registration_needed( - self, type_in_registration: str, globally_allowed_types: typing.Optional[typing.Set[str]] - ): + def is_typed_registration_needed(self, type_in_registration: str, globally_allowed_types: set[str] | None): """ Given the string from a kernel registration, determine if the registration is required or not. :param type_in_registration: Type string from kernel registration @@ -113,8 +110,8 @@ def __init__( optype: str, inputs: [int] = [0], # noqa: B006 outputs: [int] = [], # noqa: B006 - required_input_types: typing.Dict[int, typing.Set[str]] = {}, # noqa: B006 - required_output_types: typing.Dict[int, typing.Set[str]] = {}, # noqa: B006 + required_input_types: dict[int, set[str]] = {}, # noqa: B006 + required_output_types: dict[int, set[str]] = {}, # noqa: B006 ): """ Create DefaultTypeUsageProcessor. Types for one or more inputs and/or outputs can be tracked by the processor. @@ -186,9 +183,7 @@ def process_node(self, node: fbs.Node, value_name_to_typeinfo: dict): type_str = value_name_to_typestr(node.Outputs(o), value_name_to_typeinfo) self._output_types[o].add(type_str) - def is_typed_registration_needed( - self, type_in_registration: str, globally_allowed_types: typing.Optional[typing.Set[str]] - ): + def is_typed_registration_needed(self, type_in_registration: str, globally_allowed_types: set[str] | None): if 0 not in self._input_types: # currently all standard typed registrations are for input 0. # custom registrations can be handled by operator specific processors (e.g. OneHotProcessor below). @@ -262,9 +257,7 @@ def __init__(self, domain: str, optype: str): # init with tracking of input 1 only. super().__init__(domain, optype, inputs=[1], outputs=[]) - def is_typed_registration_needed( - self, type_in_registration: str, globally_allowed_types: typing.Optional[typing.Set[str]] - ): + def is_typed_registration_needed(self, type_in_registration: str, globally_allowed_types: set[str] | None): return self.is_input_type_enabled(type_in_registration, 1, globally_allowed_types) @@ -277,9 +270,7 @@ def __init__(self, domain: str, optype: str): # init with tracking of output 0 only. super().__init__(domain, optype, inputs=[], outputs=[0]) - def is_typed_registration_needed( - self, type_in_registration: str, globally_allowed_types: typing.Optional[typing.Set[str]] - ): + def is_typed_registration_needed(self, type_in_registration: str, globally_allowed_types: set[str] | None): return self.is_output_type_enabled(type_in_registration, 0, globally_allowed_types) @@ -301,9 +292,7 @@ def process_node(self, node: fbs.Node, value_name_to_typeinfo: dict): key = (type0, type2, type1) self._triples.add(key) - def is_typed_registration_needed( - self, type_in_registration: str, globally_allowed_types: typing.Optional[typing.Set[str]] - ): + def is_typed_registration_needed(self, type_in_registration: str, globally_allowed_types: set[str] | None): # the OneHot registration involves a concatenation of the 3 types involved reg_types = tuple([_reg_type_to_cpp_type(reg_type) for reg_type in _split_reg_types(type_in_registration)]) if globally_allowed_types is not None: @@ -633,7 +622,7 @@ class GloballyAllowedTypesOpTypeImplFilter(OpTypeImplFilterInterface): _valid_allowed_types = set(FbsTypeInfo.tensordatatype_to_string.values()) # noqa: RUF012 - def __init__(self, globally_allowed_types: typing.Set[str]): + def __init__(self, globally_allowed_types: set[str]): self._operator_processors = _create_operator_type_usage_processors() if not globally_allowed_types.issubset(self._valid_allowed_types): From d70c8ca7ce206fee68722b594bfb350106eb464c Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Thu, 16 Jan 2025 12:37:14 -0800 Subject: [PATCH 4/7] from __future__ import annotations --- tools/python/run_CIs_for_branch.py | 1 + tools/python/run_CIs_for_external_pr.py | 1 + tools/python/run_adb.py | 1 + tools/python/run_android_emulator.py | 1 + tools/python/sparsify_initializers.py | 1 + tools/python/util/android/android.py | 1 + tools/python/util/file_utils.py | 1 + tools/python/util/get_azcopy.py | 1 + tools/python/util/make_dynamic_shape_fixed.py | 1 + tools/python/util/mobile_helpers/test/test_usability_checker.py | 1 + tools/python/util/onnx_model_utils.py | 1 + tools/python/util/optimize_onnx_model.py | 1 + tools/python/util/reduced_build_config_parser.py | 1 + tools/python/util/run.py | 1 + 14 files changed, 14 insertions(+) diff --git a/tools/python/run_CIs_for_branch.py b/tools/python/run_CIs_for_branch.py index d1e23f28acdf4..b8d9b9d9d5f72 100644 --- a/tools/python/run_CIs_for_branch.py +++ b/tools/python/run_CIs_for_branch.py @@ -1,6 +1,7 @@ #!/usr/bin/env python3 # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. +from __future__ import annotations import argparse import json diff --git a/tools/python/run_CIs_for_external_pr.py b/tools/python/run_CIs_for_external_pr.py index faaeb4e5f7127..cee32073fa473 100644 --- a/tools/python/run_CIs_for_external_pr.py +++ b/tools/python/run_CIs_for_external_pr.py @@ -1,6 +1,7 @@ #!/usr/bin/env python3 # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. +from __future__ import annotations import argparse import json diff --git a/tools/python/run_adb.py b/tools/python/run_adb.py index 1928966b56a26..aefdb2344d050 100755 --- a/tools/python/run_adb.py +++ b/tools/python/run_adb.py @@ -1,6 +1,7 @@ #!/usr/bin/env python3 # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. +from __future__ import annotations import logging import os diff --git a/tools/python/run_android_emulator.py b/tools/python/run_android_emulator.py index 2826921726556..6d7c29fc58296 100755 --- a/tools/python/run_android_emulator.py +++ b/tools/python/run_android_emulator.py @@ -1,6 +1,7 @@ #!/usr/bin/env python3 # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. +from __future__ import annotations import argparse import contextlib diff --git a/tools/python/sparsify_initializers.py b/tools/python/sparsify_initializers.py index 9232210b5fab2..14f2e0b62c069 100644 --- a/tools/python/sparsify_initializers.py +++ b/tools/python/sparsify_initializers.py @@ -5,6 +5,7 @@ # This script opens an existing model in onnx format and attempts to # move initializers from model.graph.initializer field to model.graph.sparse_initializer field # and convert them into ONNX COO flat index format. +from __future__ import annotations import argparse import logging diff --git a/tools/python/util/android/android.py b/tools/python/util/android/android.py index 13f2b6b8d4952..8f3ed97cae53f 100644 --- a/tools/python/util/android/android.py +++ b/tools/python/util/android/android.py @@ -1,5 +1,6 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. +from __future__ import annotations import collections import contextlib diff --git a/tools/python/util/file_utils.py b/tools/python/util/file_utils.py index 20cc15580c9a8..4036841cbfd34 100644 --- a/tools/python/util/file_utils.py +++ b/tools/python/util/file_utils.py @@ -1,5 +1,6 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. +from __future__ import annotations import os import pathlib diff --git a/tools/python/util/get_azcopy.py b/tools/python/util/get_azcopy.py index bfcf228a956eb..32ad367b2a010 100644 --- a/tools/python/util/get_azcopy.py +++ b/tools/python/util/get_azcopy.py @@ -1,5 +1,6 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. +from __future__ import annotations import contextlib import logging diff --git a/tools/python/util/make_dynamic_shape_fixed.py b/tools/python/util/make_dynamic_shape_fixed.py index f4e09a8cc04a3..2dc89399a604c 100644 --- a/tools/python/util/make_dynamic_shape_fixed.py +++ b/tools/python/util/make_dynamic_shape_fixed.py @@ -1,6 +1,7 @@ #!/usr/bin/env python3 # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. +from __future__ import annotations import argparse import os diff --git a/tools/python/util/mobile_helpers/test/test_usability_checker.py b/tools/python/util/mobile_helpers/test/test_usability_checker.py index 2deacfc91dd1c..7fde729aa0053 100644 --- a/tools/python/util/mobile_helpers/test/test_usability_checker.py +++ b/tools/python/util/mobile_helpers/test/test_usability_checker.py @@ -1,5 +1,6 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. +from __future__ import annotations import logging import pathlib diff --git a/tools/python/util/onnx_model_utils.py b/tools/python/util/onnx_model_utils.py index ffefae58471ad..12fff27031e93 100644 --- a/tools/python/util/onnx_model_utils.py +++ b/tools/python/util/onnx_model_utils.py @@ -1,5 +1,6 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. +from __future__ import annotations import logging import pathlib diff --git a/tools/python/util/optimize_onnx_model.py b/tools/python/util/optimize_onnx_model.py index b7ebb54b9c8fa..c5459b2d9ff9a 100644 --- a/tools/python/util/optimize_onnx_model.py +++ b/tools/python/util/optimize_onnx_model.py @@ -1,6 +1,7 @@ #!/usr/bin/env python3 # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. +from __future__ import annotations import argparse import os diff --git a/tools/python/util/reduced_build_config_parser.py b/tools/python/util/reduced_build_config_parser.py index be39562e2d60d..0afcca2388f10 100644 --- a/tools/python/util/reduced_build_config_parser.py +++ b/tools/python/util/reduced_build_config_parser.py @@ -1,5 +1,6 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. +from __future__ import annotations import os diff --git a/tools/python/util/run.py b/tools/python/util/run.py index 838db8f789eac..b1ebd044f3420 100644 --- a/tools/python/util/run.py +++ b/tools/python/util/run.py @@ -1,5 +1,6 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. +from __future__ import annotations import logging import os From c887a1c134af6f87f737a140027a4ccd8a615559 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Thu, 16 Jan 2025 12:43:53 -0800 Subject: [PATCH 5/7] Imports in testdata --- .../test/testdata/dynamic_quantize_matmul_test.py | 2 -- onnxruntime/test/testdata/ep_partitioning_tests.py | 1 - onnxruntime/test/testdata/matmul_integer_to_float.py | 2 -- .../test/testdata/sparse_initializer_as_output.py | 8 -------- onnxruntime/test/testdata/sparse_to_dense_matmul.py | 9 --------- .../test/testdata/transform/computation_reduction.py | 2 +- .../computation_reduction/gathernd/gathernd_add.py | 2 +- .../computation_reduction/gathernd/gathernd_div.py | 2 +- .../gathernd/gathernd_layernormalization.py | 2 +- .../computation_reduction/gathernd/gathernd_matmul.py | 2 +- .../test/testdata/transform/concat_slice_elimination.py | 4 +--- onnxruntime/test/testdata/transform/cse/generate.py | 2 +- .../test/testdata/transform/expand_elimination.py | 2 +- .../test/testdata/transform/fusion/attention_gen.py | 1 - .../fusion/constant_folding_with_shape_to_initializer.py | 2 +- onnxruntime/test/testdata/transform/fusion/div_mul.py | 2 -- .../testdata/transform/fusion/dynamic_quantize_matmul.py | 2 -- .../testdata/transform/fusion/embed_layer_norm_gen.py | 2 -- onnxruntime/test/testdata/transform/fusion/fast_gelu.py | 2 +- onnxruntime/test/testdata/transform/fusion/fast_gelu2.py | 2 +- .../testdata/transform/fusion/fast_gelu3_with_casts.py | 2 +- onnxruntime/test/testdata/transform/fusion/gelu_gen.py | 2 +- .../test/testdata/transform/fusion/isinf_reducesum.py | 2 -- .../test/testdata/transform/fusion/layer_norm_t5_gen.py | 2 -- .../testdata/transform/fusion/layer_norm_with_cast_2.py | 3 --- .../testdata/transform/fusion/matmul_integer_to_float.py | 2 -- .../fusion/matmul_integer_to_float_large_tensor.py | 2 -- onnxruntime/test/testdata/transform/fusion/not_where.py | 2 -- onnxruntime/test/testdata/transform/id-elim.py | 3 +-- onnxruntime/test/testdata/transform/id-scan9_sum.py | 3 +-- .../model_parallel/bart_mlp_megatron_basic_test.py | 2 +- .../bart_self_attention_megatron_basic_test.py | 4 +--- .../transform/model_parallel/mlp_megatron_basic_test.py | 2 +- .../model_parallel/self_attention_megatron_basic_test.py | 2 +- tools/python/gen_contrib_doc.py | 9 +-------- 35 files changed, 20 insertions(+), 75 deletions(-) diff --git a/onnxruntime/test/testdata/dynamic_quantize_matmul_test.py b/onnxruntime/test/testdata/dynamic_quantize_matmul_test.py index 8e6dbe5ea581d..594da08abb1fb 100644 --- a/onnxruntime/test/testdata/dynamic_quantize_matmul_test.py +++ b/onnxruntime/test/testdata/dynamic_quantize_matmul_test.py @@ -1,5 +1,3 @@ -from enum import Enum # noqa: F401 - import onnx from onnx import TensorProto, helper diff --git a/onnxruntime/test/testdata/ep_partitioning_tests.py b/onnxruntime/test/testdata/ep_partitioning_tests.py index 6c8322bb9bd62..367cafb795bad 100644 --- a/onnxruntime/test/testdata/ep_partitioning_tests.py +++ b/onnxruntime/test/testdata/ep_partitioning_tests.py @@ -1,4 +1,3 @@ -import numpy as np # noqa: F401 import onnx from onnx import TensorProto, helper diff --git a/onnxruntime/test/testdata/matmul_integer_to_float.py b/onnxruntime/test/testdata/matmul_integer_to_float.py index e6c51009018f9..0c1ea47fff5b1 100644 --- a/onnxruntime/test/testdata/matmul_integer_to_float.py +++ b/onnxruntime/test/testdata/matmul_integer_to_float.py @@ -1,5 +1,3 @@ -from enum import Enum # noqa: F401 - import onnx from onnx import TensorProto, helper diff --git a/onnxruntime/test/testdata/sparse_initializer_as_output.py b/onnxruntime/test/testdata/sparse_initializer_as_output.py index 3a7e47910783e..25d66b40a7c73 100644 --- a/onnxruntime/test/testdata/sparse_initializer_as_output.py +++ b/onnxruntime/test/testdata/sparse_initializer_as_output.py @@ -1,21 +1,13 @@ import argparse -import os # noqa: F401 import sys import traceback -from collections.abc import Callable, Sequence # noqa: F401 import numpy as np import onnx from onnx import ( - AttributeProto, # noqa: F401 - GraphProto, # noqa: F401 - SparseTensorProto, # noqa: F401 TensorProto, ValueInfoProto, helper, - mapping, # noqa: F401 - numpy_helper, # noqa: F401 - utils, # noqa: F401 ) from onnx.helper import make_opsetid diff --git a/onnxruntime/test/testdata/sparse_to_dense_matmul.py b/onnxruntime/test/testdata/sparse_to_dense_matmul.py index bbc7f0bc0e88f..5a8a00cc7748e 100644 --- a/onnxruntime/test/testdata/sparse_to_dense_matmul.py +++ b/onnxruntime/test/testdata/sparse_to_dense_matmul.py @@ -1,21 +1,12 @@ import argparse -import os # noqa: F401 import sys import traceback -from collections.abc import Callable, Sequence # noqa: F401 -import numpy as np # noqa: F401 import onnx from onnx import ( - AttributeProto, # noqa: F401 - GraphProto, # noqa: F401 - SparseTensorProto, # noqa: F401 TensorProto, ValueInfoProto, helper, - mapping, # noqa: F401 - numpy_helper, # noqa: F401 - utils, # noqa: F401 ) from onnx.helper import make_opsetid diff --git a/onnxruntime/test/testdata/transform/computation_reduction.py b/onnxruntime/test/testdata/transform/computation_reduction.py index 6f726a54261ed..af0a39636f9ee 100644 --- a/onnxruntime/test/testdata/transform/computation_reduction.py +++ b/onnxruntime/test/testdata/transform/computation_reduction.py @@ -1,6 +1,6 @@ import numpy as np import onnx -from onnx import AttributeProto, GraphProto, OperatorSetIdProto, TensorProto, helper, numpy_helper # noqa: F401 +from onnx import OperatorSetIdProto, TensorProto, helper, numpy_helper vocab_size = 256 # 30258 diff --git a/onnxruntime/test/testdata/transform/computation_reduction/gathernd/gathernd_add.py b/onnxruntime/test/testdata/transform/computation_reduction/gathernd/gathernd_add.py index cd823ce8391c2..7caf7045ccb93 100755 --- a/onnxruntime/test/testdata/transform/computation_reduction/gathernd/gathernd_add.py +++ b/onnxruntime/test/testdata/transform/computation_reduction/gathernd/gathernd_add.py @@ -1,6 +1,6 @@ import numpy as np import onnx -from onnx import AttributeProto, GraphProto, OperatorSetIdProto, TensorProto, helper, numpy_helper # noqa: F401 +from onnx import OperatorSetIdProto, TensorProto, helper, numpy_helper X = helper.make_tensor_value_info("input", TensorProto.FLOAT, ["batch", "seqlen", 128]) unsqueezed_masked_lm_positions = helper.make_tensor_value_info( diff --git a/onnxruntime/test/testdata/transform/computation_reduction/gathernd/gathernd_div.py b/onnxruntime/test/testdata/transform/computation_reduction/gathernd/gathernd_div.py index ee25bef5c1161..86413b8679a56 100755 --- a/onnxruntime/test/testdata/transform/computation_reduction/gathernd/gathernd_div.py +++ b/onnxruntime/test/testdata/transform/computation_reduction/gathernd/gathernd_div.py @@ -1,6 +1,6 @@ import numpy as np import onnx -from onnx import AttributeProto, GraphProto, OperatorSetIdProto, TensorProto, helper, numpy_helper # noqa: F401 +from onnx import OperatorSetIdProto, TensorProto, helper, numpy_helper X = helper.make_tensor_value_info("input", TensorProto.FLOAT, ["batch", "seqlen", 128]) unsqueezed_masked_lm_positions = helper.make_tensor_value_info( diff --git a/onnxruntime/test/testdata/transform/computation_reduction/gathernd/gathernd_layernormalization.py b/onnxruntime/test/testdata/transform/computation_reduction/gathernd/gathernd_layernormalization.py index dc2abf1dda586..ffaf62a243359 100755 --- a/onnxruntime/test/testdata/transform/computation_reduction/gathernd/gathernd_layernormalization.py +++ b/onnxruntime/test/testdata/transform/computation_reduction/gathernd/gathernd_layernormalization.py @@ -1,6 +1,6 @@ import numpy as np import onnx -from onnx import AttributeProto, GraphProto, OperatorSetIdProto, TensorProto, helper, numpy_helper # noqa: F401 +from onnx import OperatorSetIdProto, TensorProto, helper, numpy_helper X = helper.make_tensor_value_info("input", TensorProto.FLOAT, ["batch", "seqlen", 128]) unsqueezed_masked_lm_positions = helper.make_tensor_value_info( diff --git a/onnxruntime/test/testdata/transform/computation_reduction/gathernd/gathernd_matmul.py b/onnxruntime/test/testdata/transform/computation_reduction/gathernd/gathernd_matmul.py index bc850c4031741..65767a8986746 100755 --- a/onnxruntime/test/testdata/transform/computation_reduction/gathernd/gathernd_matmul.py +++ b/onnxruntime/test/testdata/transform/computation_reduction/gathernd/gathernd_matmul.py @@ -1,6 +1,6 @@ import numpy as np import onnx -from onnx import AttributeProto, GraphProto, OperatorSetIdProto, TensorProto, helper, numpy_helper # noqa: F401 +from onnx import OperatorSetIdProto, TensorProto, helper, numpy_helper X = helper.make_tensor_value_info("input", TensorProto.FLOAT, ["batch", "seqlen", 128]) unsqueezed_masked_lm_positions = helper.make_tensor_value_info( diff --git a/onnxruntime/test/testdata/transform/concat_slice_elimination.py b/onnxruntime/test/testdata/transform/concat_slice_elimination.py index 9eade63328aec..97f0c6f243f60 100644 --- a/onnxruntime/test/testdata/transform/concat_slice_elimination.py +++ b/onnxruntime/test/testdata/transform/concat_slice_elimination.py @@ -1,8 +1,6 @@ -import random # noqa: F401 - import numpy as np import onnx -from onnx import GraphProto, OperatorSetIdProto, TensorProto, helper, numpy_helper # noqa: F401 +from onnx import OperatorSetIdProto, TensorProto, helper, numpy_helper batch = 3 hidden_size = 4 diff --git a/onnxruntime/test/testdata/transform/cse/generate.py b/onnxruntime/test/testdata/transform/cse/generate.py index ecca4f586f400..01d62422983b5 100644 --- a/onnxruntime/test/testdata/transform/cse/generate.py +++ b/onnxruntime/test/testdata/transform/cse/generate.py @@ -1,7 +1,7 @@ import os import onnx -from onnx import AttributeProto, GraphProto, TensorProto, helper, shape_inference # noqa: F401 +from onnx import TensorProto, helper, shape_inference _this_dir = os.path.abspath(os.path.dirname(__file__)) diff --git a/onnxruntime/test/testdata/transform/expand_elimination.py b/onnxruntime/test/testdata/transform/expand_elimination.py index 86340c9e2553c..226c23fa66389 100644 --- a/onnxruntime/test/testdata/transform/expand_elimination.py +++ b/onnxruntime/test/testdata/transform/expand_elimination.py @@ -1,6 +1,6 @@ import numpy as np import onnx -from onnx import GraphProto, OperatorSetIdProto, TensorProto, helper, numpy_helper # noqa: F401 +from onnx import OperatorSetIdProto, TensorProto, helper, numpy_helper X1 = helper.make_tensor_value_info("input1", TensorProto.FLOAT, [2, 1]) X2 = helper.make_tensor_value_info("input2", TensorProto.FLOAT, ["dynamic", 4]) diff --git a/onnxruntime/test/testdata/transform/fusion/attention_gen.py b/onnxruntime/test/testdata/transform/fusion/attention_gen.py index 19f46ab9f358a..6ff0ea5ba9983 100644 --- a/onnxruntime/test/testdata/transform/fusion/attention_gen.py +++ b/onnxruntime/test/testdata/transform/fusion/attention_gen.py @@ -1,5 +1,4 @@ import sys -from enum import Enum # noqa: F401 import onnx from onnx import TensorProto, helper diff --git a/onnxruntime/test/testdata/transform/fusion/constant_folding_with_shape_to_initializer.py b/onnxruntime/test/testdata/transform/fusion/constant_folding_with_shape_to_initializer.py index c49ae8b0a422c..65b37a8ed9dab 100644 --- a/onnxruntime/test/testdata/transform/fusion/constant_folding_with_shape_to_initializer.py +++ b/onnxruntime/test/testdata/transform/fusion/constant_folding_with_shape_to_initializer.py @@ -1,6 +1,6 @@ import numpy as np import onnx -from onnx import GraphProto, OperatorSetIdProto, TensorProto, helper, numpy_helper # noqa: F401 +from onnx import OperatorSetIdProto, TensorProto, helper, numpy_helper X = helper.make_tensor_value_info("input", TensorProto.FLOAT, [2, 4, 8]) Y = helper.make_tensor_value_info("output", TensorProto.FLOAT, [2, 4, 16]) diff --git a/onnxruntime/test/testdata/transform/fusion/div_mul.py b/onnxruntime/test/testdata/transform/fusion/div_mul.py index 8cd34a6b53fcf..e7b1f4632afbd 100644 --- a/onnxruntime/test/testdata/transform/fusion/div_mul.py +++ b/onnxruntime/test/testdata/transform/fusion/div_mul.py @@ -1,5 +1,3 @@ -from enum import Enum # noqa: F401 - import onnx from onnx import OperatorSetIdProto, TensorProto, helper diff --git a/onnxruntime/test/testdata/transform/fusion/dynamic_quantize_matmul.py b/onnxruntime/test/testdata/transform/fusion/dynamic_quantize_matmul.py index 3ec3cabbc8b77..e590b46129d7b 100644 --- a/onnxruntime/test/testdata/transform/fusion/dynamic_quantize_matmul.py +++ b/onnxruntime/test/testdata/transform/fusion/dynamic_quantize_matmul.py @@ -1,5 +1,3 @@ -from enum import Enum # noqa: F401 - import onnx from onnx import TensorProto, helper diff --git a/onnxruntime/test/testdata/transform/fusion/embed_layer_norm_gen.py b/onnxruntime/test/testdata/transform/fusion/embed_layer_norm_gen.py index 54fe7b808bf12..f83bedeb8012c 100644 --- a/onnxruntime/test/testdata/transform/fusion/embed_layer_norm_gen.py +++ b/onnxruntime/test/testdata/transform/fusion/embed_layer_norm_gen.py @@ -1,5 +1,3 @@ -from enum import Enum # noqa: F401 - import onnx from onnx import TensorProto, helper from packaging import version diff --git a/onnxruntime/test/testdata/transform/fusion/fast_gelu.py b/onnxruntime/test/testdata/transform/fusion/fast_gelu.py index 20d78b6684609..a16d7e66752bf 100644 --- a/onnxruntime/test/testdata/transform/fusion/fast_gelu.py +++ b/onnxruntime/test/testdata/transform/fusion/fast_gelu.py @@ -1,6 +1,6 @@ import numpy as np import onnx -from onnx import AttributeProto, GraphProto, OperatorSetIdProto, TensorProto, helper, numpy_helper # noqa: F401 +from onnx import OperatorSetIdProto, TensorProto, helper, numpy_helper # Gelu formula: x * 0.5 * (1.0 + tanh(0.7978845608028654 * x * (1.0 + 0.044715 * x * x))) diff --git a/onnxruntime/test/testdata/transform/fusion/fast_gelu2.py b/onnxruntime/test/testdata/transform/fusion/fast_gelu2.py index 718f924ae5902..6922f3ad0a82a 100644 --- a/onnxruntime/test/testdata/transform/fusion/fast_gelu2.py +++ b/onnxruntime/test/testdata/transform/fusion/fast_gelu2.py @@ -1,6 +1,6 @@ import numpy as np import onnx -from onnx import AttributeProto, GraphProto, OperatorSetIdProto, TensorProto, helper, numpy_helper # noqa: F401 +from onnx import OperatorSetIdProto, TensorProto, helper, numpy_helper # Gelu formula: x * 0.5 * (1.0 + tanh((sqrt(2 / pi) * (x + 0.044715 * pow(x, 3))))) has_bias = False # change it to True to generate fast_gelu_openai_with_bias.onnx diff --git a/onnxruntime/test/testdata/transform/fusion/fast_gelu3_with_casts.py b/onnxruntime/test/testdata/transform/fusion/fast_gelu3_with_casts.py index d7cfc351b8e97..d91e186296137 100644 --- a/onnxruntime/test/testdata/transform/fusion/fast_gelu3_with_casts.py +++ b/onnxruntime/test/testdata/transform/fusion/fast_gelu3_with_casts.py @@ -1,6 +1,6 @@ import numpy as np import onnx -from onnx import AttributeProto, GraphProto, OperatorSetIdProto, TensorProto, helper, numpy_helper # noqa: F401 +from onnx import OperatorSetIdProto, TensorProto, helper, numpy_helper # Gelu formula: x * 0.5 * (1.0 + tanh((sqrt(2 / pi) * (x + 0.044715 * pow(x, 3))))) diff --git a/onnxruntime/test/testdata/transform/fusion/gelu_gen.py b/onnxruntime/test/testdata/transform/fusion/gelu_gen.py index 428bb0ce00df0..8a4c3ae491215 100644 --- a/onnxruntime/test/testdata/transform/fusion/gelu_gen.py +++ b/onnxruntime/test/testdata/transform/fusion/gelu_gen.py @@ -1,6 +1,6 @@ import numpy as np import onnx -from onnx import AttributeProto, GraphProto, OperatorSetIdProto, TensorProto, helper, numpy_helper # noqa: F401 +from onnx import OperatorSetIdProto, TensorProto, helper, numpy_helper """ Generate test model for Gelu subgraph pattern 2: diff --git a/onnxruntime/test/testdata/transform/fusion/isinf_reducesum.py b/onnxruntime/test/testdata/transform/fusion/isinf_reducesum.py index c6e70fe478701..a9c88618c5c70 100644 --- a/onnxruntime/test/testdata/transform/fusion/isinf_reducesum.py +++ b/onnxruntime/test/testdata/transform/fusion/isinf_reducesum.py @@ -1,5 +1,3 @@ -from enum import Enum # noqa: F401 - import onnx from onnx import OperatorSetIdProto, TensorProto, helper diff --git a/onnxruntime/test/testdata/transform/fusion/layer_norm_t5_gen.py b/onnxruntime/test/testdata/transform/fusion/layer_norm_t5_gen.py index aa4b78f4525de..c0e2bc85f8248 100644 --- a/onnxruntime/test/testdata/transform/fusion/layer_norm_t5_gen.py +++ b/onnxruntime/test/testdata/transform/fusion/layer_norm_t5_gen.py @@ -1,5 +1,3 @@ -from enum import Enum # noqa: F401 - import onnx from onnx import OperatorSetIdProto, TensorProto, helper diff --git a/onnxruntime/test/testdata/transform/fusion/layer_norm_with_cast_2.py b/onnxruntime/test/testdata/transform/fusion/layer_norm_with_cast_2.py index 61b2e2249e7a3..fa83290138d87 100644 --- a/onnxruntime/test/testdata/transform/fusion/layer_norm_with_cast_2.py +++ b/onnxruntime/test/testdata/transform/fusion/layer_norm_with_cast_2.py @@ -1,6 +1,3 @@ -from enum import Enum # noqa: F401 - -import numpy as np # noqa: F401 import onnx from onnx import OperatorSetIdProto, TensorProto, helper diff --git a/onnxruntime/test/testdata/transform/fusion/matmul_integer_to_float.py b/onnxruntime/test/testdata/transform/fusion/matmul_integer_to_float.py index 018e5fb332dd0..f9b154c46fbd1 100644 --- a/onnxruntime/test/testdata/transform/fusion/matmul_integer_to_float.py +++ b/onnxruntime/test/testdata/transform/fusion/matmul_integer_to_float.py @@ -1,5 +1,3 @@ -from enum import Enum # noqa: F401 - import onnx from onnx import TensorProto, helper diff --git a/onnxruntime/test/testdata/transform/fusion/matmul_integer_to_float_large_tensor.py b/onnxruntime/test/testdata/transform/fusion/matmul_integer_to_float_large_tensor.py index 543517cc015ef..6b60a47255c5d 100644 --- a/onnxruntime/test/testdata/transform/fusion/matmul_integer_to_float_large_tensor.py +++ b/onnxruntime/test/testdata/transform/fusion/matmul_integer_to_float_large_tensor.py @@ -1,5 +1,3 @@ -from enum import Enum # noqa: F401 - import onnx from onnx import TensorProto, helper diff --git a/onnxruntime/test/testdata/transform/fusion/not_where.py b/onnxruntime/test/testdata/transform/fusion/not_where.py index 82a128153ac70..014d0b8fc531a 100644 --- a/onnxruntime/test/testdata/transform/fusion/not_where.py +++ b/onnxruntime/test/testdata/transform/fusion/not_where.py @@ -1,5 +1,3 @@ -from enum import Enum # noqa: F401 - import onnx from onnx import OperatorSetIdProto, TensorProto, helper diff --git a/onnxruntime/test/testdata/transform/id-elim.py b/onnxruntime/test/testdata/transform/id-elim.py index 1f7b6e2607702..eef8011e7fe23 100644 --- a/onnxruntime/test/testdata/transform/id-elim.py +++ b/onnxruntime/test/testdata/transform/id-elim.py @@ -1,6 +1,5 @@ -import numpy as np # noqa: F401 import onnx -from onnx import GraphProto, OperatorSetIdProto, TensorProto, helper, numpy_helper # noqa: F401 +from onnx import OperatorSetIdProto, TensorProto, helper X1 = helper.make_tensor_value_info("x1", TensorProto.INT64, [4, 4]) X2 = helper.make_tensor_value_info("x2", TensorProto.INT64, [4, 4]) diff --git a/onnxruntime/test/testdata/transform/id-scan9_sum.py b/onnxruntime/test/testdata/transform/id-scan9_sum.py index 7ffd2e21b7333..c813bbfc18d8e 100644 --- a/onnxruntime/test/testdata/transform/id-scan9_sum.py +++ b/onnxruntime/test/testdata/transform/id-scan9_sum.py @@ -1,6 +1,5 @@ -import numpy as np # noqa: F401 import onnx -from onnx import GraphProto, OperatorSetIdProto, TensorProto, helper, numpy_helper # noqa: F401 +from onnx import OperatorSetIdProto, TensorProto, helper initial = helper.make_tensor_value_info("initial", TensorProto.FLOAT, [2]) x = helper.make_tensor_value_info("x", TensorProto.FLOAT, [3, 2]) diff --git a/onnxruntime/test/testdata/transform/model_parallel/bart_mlp_megatron_basic_test.py b/onnxruntime/test/testdata/transform/model_parallel/bart_mlp_megatron_basic_test.py index 503d860baab67..7879bb4d4e0ff 100644 --- a/onnxruntime/test/testdata/transform/model_parallel/bart_mlp_megatron_basic_test.py +++ b/onnxruntime/test/testdata/transform/model_parallel/bart_mlp_megatron_basic_test.py @@ -1,6 +1,6 @@ import numpy as np import onnx -from onnx import AttributeProto, GraphProto, OperatorSetIdProto, TensorProto, helper, numpy_helper # noqa: F401 +from onnx import OperatorSetIdProto, TensorProto, helper, numpy_helper hidden_size = 4 weight_dim_to_split = 16 diff --git a/onnxruntime/test/testdata/transform/model_parallel/bart_self_attention_megatron_basic_test.py b/onnxruntime/test/testdata/transform/model_parallel/bart_self_attention_megatron_basic_test.py index 20bdebead3dac..886cd5c25fb08 100644 --- a/onnxruntime/test/testdata/transform/model_parallel/bart_self_attention_megatron_basic_test.py +++ b/onnxruntime/test/testdata/transform/model_parallel/bart_self_attention_megatron_basic_test.py @@ -1,8 +1,6 @@ -import random # noqa: F401 - import numpy as np import onnx -from onnx import GraphProto, OperatorSetIdProto, TensorProto, helper, numpy_helper # noqa: F401 +from onnx import OperatorSetIdProto, TensorProto, helper, numpy_helper batch = 6 hidden_size = 4 diff --git a/onnxruntime/test/testdata/transform/model_parallel/mlp_megatron_basic_test.py b/onnxruntime/test/testdata/transform/model_parallel/mlp_megatron_basic_test.py index 07487ee4880ed..5dec4899d59af 100644 --- a/onnxruntime/test/testdata/transform/model_parallel/mlp_megatron_basic_test.py +++ b/onnxruntime/test/testdata/transform/model_parallel/mlp_megatron_basic_test.py @@ -1,6 +1,6 @@ import numpy as np import onnx -from onnx import AttributeProto, GraphProto, OperatorSetIdProto, TensorProto, helper, numpy_helper # noqa: F401 +from onnx import OperatorSetIdProto, TensorProto, helper, numpy_helper hidden_size = 4 weight_dim_to_split = 16 diff --git a/onnxruntime/test/testdata/transform/model_parallel/self_attention_megatron_basic_test.py b/onnxruntime/test/testdata/transform/model_parallel/self_attention_megatron_basic_test.py index 306ad7d37403a..3749da038d93e 100644 --- a/onnxruntime/test/testdata/transform/model_parallel/self_attention_megatron_basic_test.py +++ b/onnxruntime/test/testdata/transform/model_parallel/self_attention_megatron_basic_test.py @@ -1,6 +1,6 @@ import numpy as np import onnx -from onnx import GraphProto, OperatorSetIdProto, TensorProto, helper, numpy_helper # noqa: F401 +from onnx import OperatorSetIdProto, TensorProto, helper, numpy_helper hidden_size = 4 attention_head = 2 diff --git a/tools/python/gen_contrib_doc.py b/tools/python/gen_contrib_doc.py index 54bc920ea3ddd..c190ef3b0ba7d 100644 --- a/tools/python/gen_contrib_doc.py +++ b/tools/python/gen_contrib_doc.py @@ -8,10 +8,8 @@ import pathlib import sys from collections import defaultdict -from collections.abc import Sequence # noqa: F401 -import numpy as np # type: ignore -from onnx import AttributeProto, FunctionProto # noqa: F401 +import numpy as np import onnxruntime.capi.onnxruntime_pybind11_state as rtpy from onnxruntime.capi.onnxruntime_pybind11_state import schemadef # noqa: F401 @@ -305,11 +303,6 @@ def support_level_str(level): # type: (OpSchema.SupportType) -> Text return "experimental " if level == OpSchema.SupportType.EXPERIMENTAL else "" -# def function_status_str(status=OperatorStatus.Value("EXPERIMENTAL")): # type: ignore -# return \ -# "experimental " if status == OperatorStatus.Value('EXPERIMENTAL') else "" # type: ignore - - def main(output_path: str, domain_filter: [str]): with open(output_path, "w", newline="", encoding="utf-8") as fout: fout.write("## Contrib Operator Schemas\n") From 7bf780f27456b6866cae2157ecb0ee3b7087f1d9 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Thu, 16 Jan 2025 13:24:28 -0800 Subject: [PATCH 6/7] operator_type_usage_processors --- .../util/ort_format_model/operator_type_usage_processors.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tools/python/util/ort_format_model/operator_type_usage_processors.py b/tools/python/util/ort_format_model/operator_type_usage_processors.py index 0e731fd421b0c..53f7a34015060 100644 --- a/tools/python/util/ort_format_model/operator_type_usage_processors.py +++ b/tools/python/util/ort_format_model/operator_type_usage_processors.py @@ -1,5 +1,6 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. +from __future__ import annotations import json from abc import ABC, abstractmethod From 371c46196222978bf17e8eb4ca3f54305cbbc2b4 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Thu, 16 Jan 2025 15:33:43 -0800 Subject: [PATCH 7/7] from __future__ import annotations --- tools/ci_build/build.py | 1 + tools/ci_build/op_registration_utils.py | 2 ++ tools/ci_build/op_registration_validator.py | 2 ++ tools/ci_build/patch_manylinux.py | 1 + tools/ci_build/reduce_op_kernels.py | 1 + tools/ci_build/replace_urls_in_deps.py | 1 + tools/ci_build/set-trigger-rules.py | 2 +- tools/ci_build/update_tsaoptions.py | 1 + tools/ci_build/upload_python_package_to_azure_storage.py | 1 + 9 files changed, 11 insertions(+), 1 deletion(-) diff --git a/tools/ci_build/build.py b/tools/ci_build/build.py index 9d8e44b9cab47..865d1a0c58323 100644 --- a/tools/ci_build/build.py +++ b/tools/ci_build/build.py @@ -2,6 +2,7 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates # Licensed under the MIT License. +from __future__ import annotations import argparse import contextlib diff --git a/tools/ci_build/op_registration_utils.py b/tools/ci_build/op_registration_utils.py index 0911a16d226f8..d404224a35eea 100644 --- a/tools/ci_build/op_registration_utils.py +++ b/tools/ci_build/op_registration_utils.py @@ -5,6 +5,8 @@ Utilities to help process files containing kernel registrations. """ +from __future__ import annotations + import os import pathlib import sys diff --git a/tools/ci_build/op_registration_validator.py b/tools/ci_build/op_registration_validator.py index b64e4323e8541..6cc7f3bb5ec6d 100644 --- a/tools/ci_build/op_registration_validator.py +++ b/tools/ci_build/op_registration_validator.py @@ -5,6 +5,8 @@ Validate ORT kernel registrations. """ +from __future__ import annotations + import argparse import dataclasses import itertools diff --git a/tools/ci_build/patch_manylinux.py b/tools/ci_build/patch_manylinux.py index 0d1cb37cc40ac..af03b594d9a69 100644 --- a/tools/ci_build/patch_manylinux.py +++ b/tools/ci_build/patch_manylinux.py @@ -1,6 +1,7 @@ #!/usr/bin/env python3 # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. +from __future__ import annotations import argparse import os diff --git a/tools/ci_build/reduce_op_kernels.py b/tools/ci_build/reduce_op_kernels.py index ac26abc5a9d55..f4f5cde3ddf7d 100755 --- a/tools/ci_build/reduce_op_kernels.py +++ b/tools/ci_build/reduce_op_kernels.py @@ -1,6 +1,7 @@ # !/usr/bin/env python3 # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. +from __future__ import annotations import argparse import io diff --git a/tools/ci_build/replace_urls_in_deps.py b/tools/ci_build/replace_urls_in_deps.py index 37dad358a6feb..2569b20fb44a5 100644 --- a/tools/ci_build/replace_urls_in_deps.py +++ b/tools/ci_build/replace_urls_in_deps.py @@ -4,6 +4,7 @@ # This file replaces https URLs in deps.txt to local file paths. It runs after we download the dependencies from Azure # DevOps Artifacts +from __future__ import annotations import argparse import csv diff --git a/tools/ci_build/set-trigger-rules.py b/tools/ci_build/set-trigger-rules.py index ae95d30936b83..78f59452d1284 100644 --- a/tools/ci_build/set-trigger-rules.py +++ b/tools/ci_build/set-trigger-rules.py @@ -4,7 +4,7 @@ # -------------------------------------------------------------------------- # This script is used to add trigger rules to the workflow files. - +from __future__ import annotations import multiprocessing import os diff --git a/tools/ci_build/update_tsaoptions.py b/tools/ci_build/update_tsaoptions.py index 07be746aa1981..394a45cc4ee3b 100644 --- a/tools/ci_build/update_tsaoptions.py +++ b/tools/ci_build/update_tsaoptions.py @@ -1,5 +1,6 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. +from __future__ import annotations import json import os diff --git a/tools/ci_build/upload_python_package_to_azure_storage.py b/tools/ci_build/upload_python_package_to_azure_storage.py index 16ff5d1f71611..c90ec1aa92b6b 100755 --- a/tools/ci_build/upload_python_package_to_azure_storage.py +++ b/tools/ci_build/upload_python_package_to_azure_storage.py @@ -1,6 +1,7 @@ #!/usr/bin/env python3 # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. +from __future__ import annotations import argparse import logging