Add fusions for SigLIP and Conformer-Encoder #23528

Merged 21 commits on Jan 31, 2025

Commits
52aee74  Add fusions for vision component (kunal-vaishnavi, Jan 14, 2025)
e51c299  Add fusions for vision and speech components (kunal-vaishnavi, Jan 20, 2025)
c46e180  Remove commented out code (kunal-vaishnavi, Jan 20, 2025)
6308fe0  Fix order of mask nodes (kunal-vaishnavi, Jan 22, 2025)
2ff6b94  Add helper method for transpose initializers (kunal-vaishnavi, Jan 23, 2025)
458d1ec  Add unit tests for Dynamo constant folding (kunal-vaishnavi, Jan 25, 2025)
031370c  Merge branch 'main' into kvaishnavi/phi4-mm (kunal-vaishnavi, Jan 25, 2025)
5c0e5ab  Fix issues after merge (kunal-vaishnavi, Jan 25, 2025)
4e57caf  Update type hints for newer Python versions (kunal-vaishnavi, Jan 25, 2025)
cb3247c  Add matching for extra cast nodes (kunal-vaishnavi, Jan 25, 2025)
3a3e196  Fix input name replacement for useless cast node removal (kunal-vaishnavi, Jan 26, 2025)
1f80a9a  Merge branch 'main' into kvaishnavi/phi4-mm (kunal-vaishnavi, Jan 29, 2025)
d8df5a0  Add changes suggested by linter (kunal-vaishnavi, Jan 30, 2025)
5442b50  Address some review comments (kunal-vaishnavi, Jan 30, 2025)
f00c04e  Merge remote-tracking branch 'refs/remotes/origin/kvaishnavi/phi4-mm'… (kunal-vaishnavi, Jan 30, 2025)
3ddf445  Update path matching approach based on feedback (kunal-vaishnavi, Jan 30, 2025)
9231df2  Add more changes suggested by linter (kunal-vaishnavi, Jan 30, 2025)
e420343  Fix QK nodes matching to match other styles (kunal-vaishnavi, Jan 30, 2025)
252f007  Fix unit test failures (kunal-vaishnavi, Jan 30, 2025)
8696070  Add format change suggested by linter (kunal-vaishnavi, Jan 30, 2025)
90b871f  Use logger debug instead of logger info (kunal-vaishnavi, Jan 30, 2025)
133 changes: 117 additions & 16 deletions onnxruntime/python/tools/transformers/dynamo_onnx_helper.py
@@ -2,54 +2,61 @@
 # Copyright (c) Microsoft Corporation. All rights reserved.
 # Licensed under the MIT License.
 # --------------------------------------------------------------------------
-import logging
+from collections.abc import Sequence
+from logging import getLogger
+from typing import Any

+import numpy as np
 import onnx
+from onnx import helper
+from onnx_model import OnnxModel
+
+logger = getLogger(__name__)


 class DynamoOnnxHelper:
     """
-    Helper class for processing ONNX models exported by torch Dynamo.
+    Helper class for processing ONNX models exported by Torch Dynamo.
     """

     def __init__(self, model: onnx.ModelProto):
-        self.model = model
+        self.model = OnnxModel(model)

     def update_edges(self, edge_mapping: dict) -> None:
         """
         Updates the edges in the model according to the given mapping.
         """
-        for node in self.model.graph.node:
+        for node in self.model.model.graph.node:
             for i in range(len(node.input)):
                 if node.input[i] in edge_mapping:
                     node.input[i] = edge_mapping[node.input[i]]
             for i in range(len(node.output)):
                 if node.output[i] in edge_mapping:
                     node.output[i] = edge_mapping[node.output[i]]

-        for graph_input in self.model.graph.input:
+        for graph_input in self.model.model.graph.input:
             if graph_input.name in edge_mapping:
                 graph_input.name = edge_mapping[graph_input.name]
-        for graph_output in self.model.graph.output:
+        for graph_output in self.model.model.graph.output:
             if graph_output.name in edge_mapping:
                 graph_output.name = edge_mapping[graph_output.name]

     def unroll_function(self, func_name: str) -> None:
         """
         Unrolls the function with the given name in the model.
         """
-        logging.info(f"Unrolling function {func_name}...")
+        logger.debug(f"Unrolling function {func_name}...")
         nodes_to_remove = []
         nodes_to_add = []
         edges_to_remove = []
         edges_to_add = []
-        for node in self.model.graph.node:
+        for node in self.model.model.graph.node:
             if node.op_type == func_name:
                 nodes_to_remove.append(node)
                 edges_to_remove.extend(list(node.input) + list(node.output))

         func_to_remove = None
-        for f in self.model.functions:
+        for f in self.model.model.functions:
             if f.name == func_name:
                 nodes_to_add.extend(list(f.node))
                 edges_to_add.extend(list(f.input) + list(f.output))
@@ -58,11 +65,11 @@ def unroll_function(self, func_name: str) -> None:
         assert len(edges_to_remove) == len(edges_to_add)

         for node in nodes_to_remove:
-            self.model.graph.node.remove(node)
+            self.model.model.graph.node.remove(node)
         for node in nodes_to_add:
-            self.model.graph.node.append(node)
+            self.model.model.graph.node.append(node)
         if func_to_remove is not None:
-            self.model.functions.remove(func_to_remove)
+            self.model.model.functions.remove(func_to_remove)

         edge_mapping = {}
         for i in range(len(edges_to_remove)):
@@ -79,26 +86,120 @@ def remove_function(self, func_name: str, input_id: int, output_id: int) -> None
         """
         edge_mapping = {}
         nodes_to_remove = []
-        for node in self.model.graph.node:
+        for node in self.model.model.graph.node:
             if node.op_type.find(func_name) != -1:
                 edge_mapping[node.input[input_id]] = node.output[output_id]
                 nodes_to_remove.append(node)
         for node in nodes_to_remove:
-            self.model.graph.node.remove(node)
+            self.model.model.graph.node.remove(node)

         self.update_edges(edge_mapping)

     def remove_dropout_layer(self) -> None:
         """
         Removes the dropout layer in the model.
         """
-        logging.info("Removing dropout layer...")
+        logger.debug("Removing dropout layer...")
         self.remove_function("Dropout", 0, 0)

     def remove_lm_head_layer(self) -> None:
         """
         Removes the LM head layer in the model.
         """
-        logging.info("Removing LM head layer...")
+        logger.debug("Removing LM head layer...")
         # bugbug: need to copy the right vi over
         self.remove_function("Linear_lm_head", 2, 0)
+
+    def add_initializer(self, name: str, data_type: int, dims: Sequence[int], vals: Any, raw: bool = True):
+        if raw:
+            np_type = helper.tensor_dtype_to_np_dtype(data_type)
+            if not isinstance(vals, np.ndarray):
+                bytes = np.array(vals, dtype=np_type).tobytes()
+            else:
+                bytes = vals.astype(np_type).tobytes()
+            tensor = helper.make_tensor(
+                name=name,
+                data_type=data_type,
+                dims=dims,
+                vals=bytes,
+                raw=True,
+            )
+        else:
+            tensor = helper.make_tensor(
+                name=name,
+                data_type=data_type,
+                dims=dims,
+                vals=vals,
+                raw=False,
+            )
+
+        self.model.add_initializer(tensor)
+        return tensor
+
+    def convert_constants_to_initializers(self, min_size: int = 1) -> None:
+        """
+        Converts Constant ops of size [min_size] or higher to initializers
+        """
+        logger.debug(f"Converting constants greater than size {min_size} to initializers")
+
+        constant_nodes = self.model.get_nodes_by_op_type("Constant")
+        nodes_to_remove = []
+
+        for node in constant_nodes:
+            # Get info from Constant op
+            np_data = self.model.get_constant_value(node.output[0])
+
+            # Skip if there are less than [min_size] elements
+            if np_data is None or np_data.size < min_size:
+                continue
+
+            # Add new initializer with same name as Constant op's output
+            for att in node.attribute:
+                if att.name == "value":
+                    self.add_initializer(
+                        name=node.output[0],
+                        data_type=att.t.data_type,
+                        dims=list(np_data.shape),
+                        vals=np_data,
+                    )
+                    break
+
+            nodes_to_remove.append(node)
+
+        # Remove Constant ops from graph
+        self.model.remove_nodes(nodes_to_remove)
+
+    def clear_metadata(self) -> None:
+        """
+        Clear metadata fields in all nodes
+        """
+        for graph in self.model.graphs():
+            graph.ClearField("metadata_props")
+        for node in self.model.nodes():
+            node.ClearField("metadata_props")
+
+    @staticmethod
+    def fold_transpose_initializers(model) -> None:
+        """
+        Constant fold Transpose initializers without changing the initializer names
+        """
+        from onnxscript import ir
+
+        for name, initializer in model.graph.initializers.items():
+            user_nodes = initializer.consumers()
+            if len(user_nodes) == 1 and user_nodes[0].op_type == "Transpose":
+                transpose_node = user_nodes[0]
+                perm = transpose_node.attributes.get("perm")
+                if perm is None:
+                    transposed_tensor = ir.tensor(initializer.const_value.numpy().transpose())
+                else:
+                    transposed_tensor = ir.tensor(initializer.const_value.numpy().transpose(perm.as_ints()))
+                new_initializer = ir.Value(
+                    name=initializer.name,
+                    shape=transposed_tensor.shape,
+                    type=ir.TensorType(transposed_tensor.dtype),
+                    const_value=transposed_tensor,
+                )
+                ir.convenience.replace_all_uses_with(transpose_node.outputs[0], new_initializer)
+                model.graph.initializers[name] = new_initializer
+                transpose_node.graph.remove(transpose_node, safe=True)
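Usage note: a minimal sketch of how the helpers added above might be chained in a post-export pass. The driver code below is not part of this PR; the file names, the min_size value, and the onnxscript IR round-trip are illustrative assumptions.

import onnx
from onnxscript import ir

from dynamo_onnx_helper import DynamoOnnxHelper

# Load a Dynamo-exported model (hypothetical file name).
proto = onnx.load("vision_encoder.onnx")
dynamo_helper = DynamoOnnxHelper(proto)

# Promote Constant ops to initializers and strip export metadata.
dynamo_helper.convert_constants_to_initializers(min_size=1)
dynamo_helper.clear_metadata()

# fold_transpose_initializers operates on the onnxscript IR view of the
# model, so round-trip through ir.serde (assumed entry points).
ir_model = ir.serde.deserialize_model(dynamo_helper.model.model)
DynamoOnnxHelper.fold_transpose_initializers(ir_model)
onnx.save(ir.serde.serialize_model(ir_model), "vision_encoder_opt.onnx")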
6 changes: 2 additions & 4 deletions onnxruntime/python/tools/transformers/fusion_attention.py
@@ -838,12 +838,10 @@ def create_attention_node(
             attention_inputs.append(past_kv)

         if add_qk_str:
-            mask_output_name = self.reshape_add_qk(add_qk_str)
-
-            # Add attention mask to attention node
+            # Add additional add to attention node (input name = attention_bias)
             if not past_exists:
                 attention_inputs.append("")
-            attention_inputs.append(mask_output_name)
+            attention_inputs.append(add_qk_str)

         attention_outputs = [output]
         if present_k and present_v:
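On the fusion_attention.py change: dropping the reshape_add_qk step means the raw add_qk tensor name flows straight into the Attention node's attention_bias input. A rough sketch of the resulting node, with made-up tensor names and an example num_heads, assuming the com.microsoft Attention input layout:

from onnx import helper

attention_node = helper.make_node(
    "Attention",
    inputs=[
        "layer_norm_out",  # input (hypothetical name)
        "qkv_weight",      # fused Q/K/V weight
        "qkv_bias",        # fused Q/K/V bias
        "",                # mask_index: unused
        "",                # past: appended as "" when no past state exists
        "add_qk",          # attention_bias: add_qk_str wired through unchanged
    ],
    outputs=["attention_out"],
    name="Attention_0",
    domain="com.microsoft",
    num_heads=16,  # example value
)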