Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rewrite rules implementation for LLaMA-2/LLaMA-3 #1811

Open
wants to merge 20 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion onnxscript/ir/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -753,7 +753,7 @@

class SymbolicDim(_protocols.SymbolicDimProtocol, _display.PrettyPrintable):
__slots__ = ("_value",)

Check warning on line 756 in onnxscript/ir/_core.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/ir/_core.py#L756

Added line #L756 was not covered by tests
def __init__(self, value: str | None) -> None:
"""Initialize a symbolic dimension.

Expand Down
55 changes: 30 additions & 25 deletions onnxscript/rewriter/function_rule.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,22 +86,15 @@

class FunctionRewriteRule(pattern.RewriteRule):
FUNCTION_KEYWORD: str | tuple[str]
"""The keyword to match the function name. If a tuple, any keyword will match."""

PACKAGE_NAME: str
"""The package name to match.

For example, 'transformers' to match for domain name 'pkg.transformers.4.36.2'.
"""

_opset_imports: dict[str, int]
onnx_opset: onnxscript.values.Opset

def __init__(self, opset: onnxscript.values.Opset = onnxscript.opset18) -> None: # type: ignore[has-type]
def __init__(self, opset: onnxscript.values.Opset = onnxscript.opset18) -> None:

Check failure

Code scanning / lintrunner

MYPY/has-type Error

Cannot determine type of "opset18" To disable, use # type: ignore[has-type]
self.onnx_opset = opset

def _match_function(self, function: ir.Function, pkg_name: str) -> bool:
# TODO: Consolidate more checks from `compose_new_function` to here.
print("----> Checking function:", function.name, "in package:", pkg_name)
if pkg_name != self.PACKAGE_NAME:
logger.info(
"Rule %s did not match function %s::%s. Package name mismatch '%s' != '%s'.",
Expand All @@ -113,11 +106,15 @@
)
return False
if isinstance(self.FUNCTION_KEYWORD, str):
return function.name.find(self.FUNCTION_KEYWORD) != -1
match = function.name.find(self.FUNCTION_KEYWORD) != -1
print(f"----> Function name '{function.name}' match with '{self.FUNCTION_KEYWORD}': {match}")
return match
elif isinstance(self.FUNCTION_KEYWORD, tuple):
return any(function.name.find(keyword) != -1 for keyword in self.FUNCTION_KEYWORD)
match = any(function.name.find(keyword) != -1 for keyword in self.FUNCTION_KEYWORD)
print(f"----> Function name '{function.name}' match with any of '{self.FUNCTION_KEYWORD}': {match}")
return match

Check warning on line 115 in onnxscript/rewriter/function_rule.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/function_rule.py#L114-L115

Added lines #L114 - L115 were not covered by tests
else:
raise ValueError( # noqa: TRY004
raise ValueError(

Check warning

Code scanning / lintrunner

RUFF/TRY004 Warning

Prefer TypeError exception for invalid type.
See https://docs.astral.sh/ruff/rules/type-check-without-type-error

Check warning on line 117 in onnxscript/rewriter/function_rule.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/function_rule.py#L117

Added line #L117 was not covered by tests
f"Function keyword must be str or tuple, got {self.FUNCTION_KEYWORD}"
)

Expand All @@ -129,10 +126,17 @@
return node
return None

def _find_function_by_name(
self, function: ir.Function, keyword: str
) -> ir.Function | None:
for node in function:
if node.name.find(keyword) != -1:

Check failure

Code scanning / lintrunner

MYPY/union-attr Error

Item "None" of "str | None" has no attribute "find" To disable, use # type: ignore[union-attr]
return node

Check failure

Code scanning / lintrunner

MYPY/return-value Error

Incompatible return value type (got "Node", expected "Function | None") To disable, use # type: ignore[return-value]
return None

Check warning on line 135 in onnxscript/rewriter/function_rule.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/function_rule.py#L134-L135

Added lines #L134 - L135 were not covered by tests

def _find_node_by_type(
self, function: ir.Function, domain: str, op_type: str
) -> ir.Node | None:
# Repeat
for node in function:
if node.domain == domain and node.op_type == op_type:
return node
Expand All @@ -141,18 +145,12 @@
def compose_new_function(
self, old_function: ir.Function, pkg_version: version.Version | None
) -> ir.Function:
"""Compose a new function from the old function.

Returns:
A tuple of the new function and the opset imports.

Raises:
FunctionRewriteError: If the rewrite fails.
"""
# self._version_controller is created in the subclass
func = self._version_controller.dispatch(pkg_version) # type: ignore[attr-defined]
print("----> (2) pkg_version", pkg_version, "old_function", old_function.name)
func = self._version_controller.dispatch(pkg_version)

Check failure

Code scanning / lintrunner

MYPY/attr-defined Error

"FunctionRewriteRule" has no attribute "_version_controller" To disable, use # type: ignore[attr-defined]
if func is not None:
print("----> (2.5) Dispatch function found, applying...")
new_function = func(self, old_function)
print("----> (2.6) New function created.")
return new_function
raise FunctionRewriteError(
f"No rewrite implementation for package version {pkg_version}."
Expand All @@ -163,6 +161,7 @@
) -> tuple[ir.OperatorIdentifier, ir.Function] | None:
try:
pkg_name, pkg_version = parse_domain(function.domain)
print("----> (1) Parsed domain, pkg_name:", pkg_name, "pkg_version:", pkg_version)
except FunctionRewriteError as e:
logger.warning("Could not parse domain: %s", e)
return None
Expand All @@ -177,33 +176,39 @@
)

if not self._match_function(function, pkg_name):
print("----> (1.5) Function does not match.")
return None
logger.info(
"Rule %s matched function %s::%s",
self.__class__.__name__,
function.domain,
function.name,
)
print("----> (1.6) Function matched.")
try:
new_function = self.compose_new_function(function, pkg_version)
except FunctionRewriteError as e:
logger.warning("Could not rewrite function: %s", e)
return None

if not hasattr(new_function, 'name'):
logger.error("new_function does not have a 'name' attribute. Received: %s", type(new_function))
return None

Check warning on line 196 in onnxscript/rewriter/function_rule.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/function_rule.py#L195-L196

Added lines #L195 - L196 were not covered by tests

new_function.name = function.name
new_function.domain = function.domain

return function.identifier(), new_function

def try_rewrite(self, model: ir.Model, value) -> bool:
raise NotImplementedError(
"Use `try_rewrite_function` instead for function based rewrites."
"Use try_rewrite_function instead for function based rewrites."
)

def apply_to_model(
self, model: ir.Model, *, commute: bool = False
) -> tuple[int, ir.Model]:
del commute # unused
del commute

old_function_to_new_function: dict[ir.OperatorIdentifier, ir.Function] = {}
for function in model.functions.values():
Expand Down
8 changes: 8 additions & 0 deletions onnxscript/rewriter/onnxruntime/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,12 @@
layernorm.LNRewriteRule,
fastgelu.GeluRewriteRule,
biassplitgelu.GegluRewriteRule,
multihead_attention.MHALlama2RewriteRule,
multihead_attention.GQALlama3RewriteRule,
multihead_attention.AttentionRewriteRule,
multihead_attention.MLP3RewriteRule,
multihead_attention.GQALlama3RewriteRuleFirstAttention,
multihead_attention.MLPRewriteRule,
multihead_attention.GQALlamaRewriteRule,

]
23 changes: 9 additions & 14 deletions onnxscript/rewriter/onnxruntime/transformers/layernorm.py
Original file line number Diff line number Diff line change
@@ -1,43 +1,38 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations

import logging

import onnxscript
from onnxscript import ir
from onnxscript.rewriter import _ir_utils, function_rule

import logging
logger = logging.getLogger(__name__)


class LNRewriteRule(function_rule.FunctionRewriteRule):
FUNCTION_KEYWORD = "layernorm"
FUNCTION_KEYWORD = "norm"
PACKAGE_NAME = "transformers"
_version_controller = function_rule.VersionController()

@_version_controller.register_version()
@_version_controller.register_version(min_version="4.40", max_version="4.50")
def _fusion(self, function: ir.Function) -> ir.Function:
# TODO(bowbao): Might be more desirable to annotate as attribute in nn.Module
aten_add_node = self._find_node_by_type(function, "", "Add")
# Depending on the graph, you may have to find the node by name or by type instead of by function.
aten_add_node = self._find_function_by_name(function, "aten_add")

Check warning on line 18 in onnxscript/rewriter/onnxruntime/transformers/layernorm.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/onnxruntime/transformers/layernorm.py#L18

Added line #L18 was not covered by tests
if aten_add_node is None:
raise function_rule.FunctionRewriteError("Could not find Add node")

raise function_rule.FunctionRewriteError("Could not find Add node")

Check warning on line 21 in onnxscript/rewriter/onnxruntime/transformers/layernorm.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/onnxruntime/transformers/layernorm.py#L21

Added line #L21 was not covered by tests
eps_ir_value = _ir_utils.propagate_const_value(aten_add_node.inputs[1])
eps_const_value = eps_ir_value.const_value
if eps_const_value is None:
print("could not find")

Check warning on line 25 in onnxscript/rewriter/onnxruntime/transformers/layernorm.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/onnxruntime/transformers/layernorm.py#L25

Added line #L25 was not covered by tests
raise function_rule.FunctionRewriteError("Could not find eps")
eps_numpy_value = eps_const_value.numpy()
eps = eps_numpy_value.item()
logger.info("eps: %s", eps)

# TODO(ORT): SimplifiedLayerNormalization in ort is defined under onnx domain.
# https://github.com/microsoft/onnxruntime/issues/7573
# msft_op = onnxscript.values.Opset("com.microsoft", 1)
op = self.onnx_opset
msft_op = onnxscript.values.Opset("com.microsoft", 1)

Check warning on line 32 in onnxscript/rewriter/onnxruntime/transformers/layernorm.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/onnxruntime/transformers/layernorm.py#L32

Added line #L32 was not covered by tests

def ln(input, weight):
return op.SimplifiedLayerNormalization(
return msft_op.SimplifiedLayerNormalization(

Check warning on line 35 in onnxscript/rewriter/onnxruntime/transformers/layernorm.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/onnxruntime/transformers/layernorm.py#L35

Added line #L35 was not covered by tests
input, weight, axis=-1, epsilon=eps, stash_type=1
)

Expand Down
Loading
Loading