Rewrite rules implementation for LLaMA-2/ LLaMA-3 #1811

Open
wants to merge 20 commits into
base: main


kobby-kobbs

Summary

This PR implements LLaMA 2 and LLaMA 3 rewrite rules for the MLP (LlamaMLP) and attention (LlamaAttention) layers in transformers. The rules target transformers versions 4.39 to 4.42 and handle the optimization and fusion of these layers.

Key Changes

MLP RewriteRule:

Adds a rewrite rule that optimizes the LLaMA MLP layer (LlamaMLP) in transformers versions 4.39 to 4.42.
The rule handles both input counts (5 or 6) and performs the matrix multiplication and activation operations needed to produce the optimized output.
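
For reference, the computation that an MLP rewrite has to preserve can be sketched as follows. This is a minimal pure-Python illustration, not the code in this PR: it assumes the standard Hugging Face LlamaMLP form, `down_proj(silu(gate_proj(x)) * up_proj(x))`, without biases, and the helper names (`silu`, `matvec`, `llama_mlp`) are hypothetical.

```python
import math


def silu(x: float) -> float:
    """SiLU activation: x * sigmoid(x)."""
    return x / (1.0 + math.exp(-x))


def matvec(weight, vec):
    """Multiply a row-major matrix (list of rows) by a vector."""
    return [sum(w * v for w, v in zip(row, vec)) for row in weight]


def llama_mlp(x, gate_w, up_w, down_w):
    """Reference gated MLP: down(silu(gate(x)) * up(x)), no biases (assumption)."""
    gate = [silu(g) for g in matvec(gate_w, x)]
    up = matvec(up_w, x)
    hidden = [g * u for g, u in zip(gate, up)]  # element-wise gating
    return matvec(down_w, hidden)


if __name__ == "__main__":
    x = [1.0, 2.0]
    gate_w = [[0.5, -0.5], [1.0, 0.0], [0.0, 1.0]]  # hidden size 2 -> 3
    up_w = [[1.0, 1.0], [0.0, 1.0], [1.0, 0.0]]
    down_w = [[1.0, 0.0, 0.0], [0.0, 1.0, 1.0]]  # 3 -> hidden size 2
    print(llama_mlp(x, gate_w, up_w, down_w))
```

A fused replacement is only valid if it yields the same result as this unfused sequence for every input, which is what a numerical comparison test for the rule would check.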

GQA Llama RewriteRule:

Introduces a rewrite rule for the LlamaAttention layer, including the first attention layer, with support for a specified number of inputs.
Two methods handle the 2D and 4D cache configurations used in Grouped Query Attention (GQA), enabling optimized matrix multiplication and attention operations.
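
As background for the GQA rule, the sketch below shows the general idea of grouped query attention for a single decode step: several query heads share one key/value head, so the cached K/V heads are repeated to line up with the query heads. It is a pure-Python illustration under assumptions. The PR's 2D and 4D cache layouts are not described here, so the cache is modelled as nested lists indexed `[kv_head][position][dim]`, and all function names are hypothetical.

```python
import math


def repeat_kv(kv_heads, n_rep):
    """Repeat each KV head n_rep times so it lines up with the query heads."""
    return [head for head in kv_heads for _ in range(n_rep)]


def attend(q, keys, values):
    """Scaled dot-product attention of one query vector over cached keys/values."""
    scale = 1.0 / math.sqrt(len(q))
    scores = [scale * sum(a * b for a, b in zip(q, k)) for k in keys]
    top = max(scores)  # subtract the max for a numerically stable softmax
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    weights = [e / total for e in exps]
    dim = len(values[0])
    return [sum(w * v[d] for w, v in zip(weights, values)) for d in range(dim)]


def grouped_query_attention(queries, k_cache, v_cache):
    """queries: [num_q_heads][dim]; caches: [num_kv_heads][position][dim]."""
    n_rep = len(queries) // len(k_cache)  # query heads per KV head
    keys = repeat_kv(k_cache, n_rep)
    values = repeat_kv(v_cache, n_rep)
    return [attend(q, k, v) for q, k, v in zip(queries, keys, values)]


if __name__ == "__main__":
    queries = [[1.0, 1.0]] * 4  # 4 query heads
    k_cache = [[[1.0, 0.0], [1.0, 0.0]], [[0.0, 1.0], [0.0, 1.0]]]  # 2 KV heads
    v_cache = [[[2.0, 0.0], [4.0, 0.0]], [[0.0, 10.0], [0.0, 20.0]]]
    print(grouped_query_attention(queries, k_cache, v_cache))
```

Query heads 0 and 1 read from KV head 0 and heads 2 and 3 read from KV head 1, which is the sharing pattern a fused GQA operator exploits to avoid materializing the repeated heads.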


codecov bot commented Aug 15, 2024

Codecov Report

Attention: Patch coverage is 25.39062% with 191 lines in your changes missing coverage. Please review.

Project coverage is 73.50%. Comparing base (4c3a6be) to head (03ba9e9).
Report is 88 commits behind head on main.

| Files | Patch % | Lines |
|---|---|---|
| ...er/onnxruntime/transformers/multihead_attention.py | 21.62% | 174 Missing ⚠️ |
| onnxscript/rewriter/function_rule.py | 56.00% | 10 Missing and 1 partial ⚠️ |
| ...ipt/rewriter/onnxruntime/transformers/layernorm.py | 37.50% | 5 Missing ⚠️ |
| onnxscript/ir/_core.py | 0.00% | 1 Missing ⚠️ |
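
The per-file counts can be cross-checked against the headline figure with a few lines of arithmetic. This is a sketch under one assumption: the patch size of 256 lines is inferred from the reported 25.39062% and 191 uncovered lines, not stated in the report. Partially covered lines are counted as uncovered, which is what makes the per-file numbers sum to 191.

```python
# Uncovered lines per file as listed in the report (missing + partial).
# File paths are kept truncated exactly as the report shows them.
uncovered = {
    "...er/onnxruntime/transformers/multihead_attention.py": 174,
    "onnxscript/rewriter/function_rule.py": 10 + 1,
    "...ipt/rewriter/onnxruntime/transformers/layernorm.py": 5,
    "onnxscript/ir/_core.py": 1,
}

total_uncovered = sum(uncovered.values())

# Assumption: 256 changed lines, inferred from 191 / (1 - 0.2539062).
patch_lines = 256
covered = patch_lines - total_uncovered
patch_coverage = 100 * covered / patch_lines

print(total_uncovered, covered, patch_coverage)
```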
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1811      +/-   ##
==========================================
- Coverage   75.95%   73.50%   -2.45%     
==========================================
  Files         228      248      +20     
  Lines       24246    26893    +2647     
  Branches     4201     4915     +714     
==========================================
+ Hits        18416    19768    +1352     
- Misses       5035     6161    +1126     
- Partials      795      964     +169     


testingg.py
onnx_model = onnx.load(output_model_path, load_external_data=False)

# Apply the inliner
onnx_model = onnx.inliner.inline_local_functions(onnx_model)

Check notice

Code scanning / CodeQL

Unused local variable Note test

Variable onnx_model is not used.
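
One way to clear this notice is to actually use the inlined model, for example by saving it. The sketch below assumes that persisting the result was the intent, which the thread does not confirm; the function name and both path parameters are hypothetical. It requires an onnx release that ships `onnx.inliner`.

```python
def inline_and_save(model_path: str, inlined_path: str) -> None:
    """Load a model, inline its local functions, and save the result."""
    # Imported lazily so this module can be loaded without onnx installed.
    import onnx
    import onnx.inliner

    # As in the snippet above, external tensor data is not loaded into memory.
    model = onnx.load(model_path, load_external_data=False)

    # Replace calls to model-local functions with their bodies.
    inlined = onnx.inliner.inline_local_functions(model)

    # Using the result resolves the "unused local variable" notice.
    # Note: tensors still reference their external data files, so saving
    # next to the original model keeps those references valid.
    onnx.save(inlined, inlined_path)
```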
initializer.py Outdated
@@ -0,0 +1,231 @@

Collaborator


I suggest excluding this file for now. We can focus on the rewriter rules for this PR.

Author


Got it


github-advanced-security bot left a comment


lintrunner found more than 20 potential problems in the proposed changes. Check the Files changed tab for more details.

@justinchuby
Collaborator

Congrats on your first PR! 🎉 For autofix-able lint errors, you can follow https://github.com/microsoft/onnxscript#coding-style to run the autofix.

2 participants