Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve style | feat(codegen) #899

Draft
wants to merge 1 commit into
base: gh/justinchuby/35/base
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 11 additions & 12 deletions onnxscript/codeanalysis/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,13 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
# pylint: disable=import-outside-toplevel
# pylint: disable=too-many-ancestors
# --------------------------------------------------------------------------

from __future__ import annotations

import dataclasses
import os
import pathlib
from collections import defaultdict
from dataclasses import dataclass
from pathlib import Path
from typing import Final, Protocol, Sequence, runtime_checkable

import libcst as cst
Expand All @@ -28,12 +25,12 @@
]


def format_code(path: Path | None, code: bytes) -> bytes:
def format_code(path: pathlib.Path | None, code: bytes) -> bytes:
try:
import ufmt
import ufmt # pylint: disable=import-outside-toplevel

if path is None:
path = Path(os.curdir)
path = pathlib.Path(os.curdir)

return ufmt.ufmt_bytes(
path,
Expand Down Expand Up @@ -84,7 +81,7 @@ def make_const_expr(const: str | int | float) -> cst.BaseExpression:
return val


@dataclass
@dataclasses.dataclass
class ImportAlias:
name: str
alias: str | None = None
Expand All @@ -95,15 +92,15 @@ def to_cst(self) -> cst.ImportAlias:
)


@dataclass
@dataclasses.dataclass
class Import:
module: ImportAlias

def to_cst(self) -> cst.Import:
return cst.Import(names=[self.module.to_cst()])


@dataclass
@dataclasses.dataclass
class ImportFrom:
module: str
names: list[ImportAlias]
Expand All @@ -121,7 +118,9 @@ def analyze_scopes(self, scopes: set[cstmeta.Scope]):
pass


class RemoveUnusedImportsTransformer(cst.CSTTransformer, ScopeAnalyzer):
class RemoveUnusedImportsTransformer(
cst.CSTTransformer, ScopeAnalyzer
): # pylint: disable=too-many-ancestors
def __init__(self):
self.__unused_imports: dict[cst.Import | cst.ImportFrom, set[str]] = defaultdict(set)

Expand Down
12 changes: 6 additions & 6 deletions onnxscript/codeanalysis/onnx_to_onnxscript.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@

from __future__ import annotations

import dataclasses
import pathlib
import re
from dataclasses import dataclass
from pathlib import Path
from typing import BinaryIO, Final, Literal, Sequence, cast, overload

import libcst as cst
Expand All @@ -33,7 +33,7 @@
DEFAULT_OPSET_VERSION: Final = 18


@dataclass
@dataclasses.dataclass
class QualifiedOnnxOp:
domain: str
name: str
Expand Down Expand Up @@ -501,10 +501,10 @@ class Driver:

def __init__(
self,
model: onnx.ModelProto | Path | str | BinaryIO,
model: onnx.ModelProto | pathlib.Path | str | BinaryIO,
transformers: Sequence[cst.CSTTransformer] | None = None,
):
if isinstance(model, Path):
if isinstance(model, pathlib.Path):
model = str(model.resolve())
if not isinstance(model, onnx.ModelProto):
model = onnx.load_model(model)
Expand All @@ -526,7 +526,7 @@ def to_cst_module(self) -> cst.Module:
cst_module = codegen.apply_transformers(cst_module, self.transformers)
return cst_module

def to_python_code(self, reference_path: Path | None = None) -> bytes:
def to_python_code(self, reference_path: pathlib.Path | None = None) -> bytes:
return format_code(
path=reference_path,
code=self.to_cst_module().bytes,
Expand Down
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ classifiers = [
"Programming Language :: Python :: 3.11",
"License :: OSI Approved :: Apache Software License",
]
dependencies = ["numpy", "onnx>=1.14", "typing_extensions"]
dependencies = ["numpy", "onnx>=1.14", "typing_extensions", "libcst"]

[tool.setuptools.packages.find]
include = ["onnxscript*"]
Expand Down Expand Up @@ -60,6 +60,7 @@ module = [
"onnxruntime.*",
"parameterized.*",
"torchgen.*",
"ufmt.*",
]
ignore_missing_imports = true

Expand Down