-
Notifications
You must be signed in to change notification settings - Fork 53
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add unit test to investigate torch issues (scaled_dot_product_attention, index_put) #1864
base: main
Are you sure you want to change the base?
Conversation
Signed-off-by: xadupre <[email protected]>
Signed-off-by: xadupre <[email protected]>
Codecov ReportAll modified and coverable lines are covered by tests ✅
Additional details and impacted files@@ Coverage Diff @@
## main #1864 +/- ##
==========================================
+ Coverage 75.26% 75.32% +0.06%
==========================================
Files 251 251
Lines 27446 27446
Branches 5032 5032
==========================================
+ Hits 20656 20673 +17
+ Misses 5822 5808 -14
+ Partials 968 965 -3 ☔ View full report in Codecov by Sentry. |
Signed-off-by: xadupre <[email protected]>
key_states = torch.randn(batch_size, seq_length_kv, embedding_dim) | ||
value_states = torch.randn(batch_size, seq_length_kv, embedding_dim) | ||
|
||
output = model(query_states, key_states, value_states) |
Check warning
Code scanning / CodeQL
Variable defined multiple times Warning test
redefined
else: | ||
raise AssertionError(f"Unknown exporter {exporter!r}") | ||
|
||
import onnxruntime |
Check notice
Code scanning / lintrunner
PYLINT/C0415 Note test
See import-outside-toplevel. To disable, use # pylint: disable=import-outside-toplevel
onnx_file_path = f"scaled_dot_product_attention_{exporter}.onnx" | ||
|
||
if exporter == "script": | ||
torch.onnx.export( |
Check failure
Code scanning / lintrunner
PYLINT/E1123 Error test
See unexpected-keyword-arg. To disable, use # pylint: disable=unexpected-keyword-arg
else: | ||
raise AssertionError(f"Unknown exporter {exporter!r}") | ||
|
||
import onnxruntime |
Check notice
Code scanning / lintrunner
PYLINT/C0415 Note test
See import-outside-toplevel. To disable, use # pylint: disable=import-outside-toplevel
Do you have plans to merge or is this for investigation only? Marking as draft for now |
We do need to rewrite our index put implementation. #1749 |
I don't have time to implement a fix this week but anybody doing it should check the with unit tests I made and decide whether or not they should be kept. |
See pytorch/pytorch#135615, pytorch/pytorch#135233.