-
Notifications
You must be signed in to change notification settings - Fork 53
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[draft] Xiaowu/fix bug(embedding bag) #1099
base: main
Are you sure you want to change the base?
Conversation
Codecov ReportAttention:
Additional details and impacted files@@ Coverage Diff @@
## main #1099 +/- ##
===========================================
- Coverage 78.55% 56.93% -21.63%
===========================================
Files 118 106 -12
Lines 15154 12752 -2402
Branches 1620 1331 -289
===========================================
- Hits 11904 7260 -4644
- Misses 2872 5189 +2317
+ Partials 378 303 -75 ☔ View full report in Codecov by Sentry. |
Test Results 18 files ± 0 18 suites ±0 26m 13s ⏱️ - 37m 32s For more details on these errors, see this check. Results for commit 95a24dd. ± Comparison against base commit 10f9a1f. This pull request removes 3639 and adds 2403 tests. Note that renamed tests count towards both.
This pull request removes 1409 skipped tests and adds 459 skipped tests. Note that renamed tests count towards both.
♻️ This comment has been updated with latest results. |
print(max_indices) | ||
|
||
def test_embedding_bag_aten(): | ||
import torch as t |
Check notice
Code scanning / lintrunner
PYLINT/C0415 Note
See import-outside-toplevel. To disable, use # pylint: disable=import-outside-toplevel
print(max_indices) | ||
|
||
def test_embedding_bag_nn_function(): | ||
import torch as t |
Check notice
Code scanning / lintrunner
PYLINT/C0415 Note
See import-outside-toplevel. To disable, use # pylint: disable=import-outside-toplevel
[-4.2188, -4.2266, -2.7246, -6.8555, -7.6719]], dtype=t.float16) | ||
indices = t.tensor([4, 9, 3, 0, 3], dtype=t.int64) | ||
offsets = t.tensor([0, 3], dtype=t.int64) | ||
mode = 0 # sum |
Check warning
Code scanning / lintrunner
PYLINT/W0612 Warning
See unused-variable. To disable, use # pylint: disable=unused-variable
[-4.2188, -4.2266, -2.7246, -6.8555, -7.6719]], dtype=t.float16) | ||
indices = t.tensor([4, 9, 3, 0, 3], dtype=t.int64) | ||
offsets = t.tensor([0, 3], dtype=t.int64) | ||
mode = 0 # sum |
Check warning
Code scanning / lintrunner
RUFF/F841 Warning
See https://beta.ruff.rs/docs/rules/
@@ -3046,6 +3046,104 @@ | |||
return result, offset2bag, bag_size, max_indices | |||
|
|||
|
|||
|
|||
def test_embedding_bag_onnx(): | |||
import numpy as np |
Check notice
Code scanning / lintrunner
PYLINT/C0415 Note
See import-outside-toplevel. To disable, use # pylint: disable=import-outside-toplevel
# include_last_offset = True | ||
per_sample_weights = np.array([2.4134, -0.1783, 7.1360, -0.7987, 2.3815], dtype=np.float16) | ||
#per_sample_weights = np.array([-2.2930, 6.2148, 3.1562, 0.0791, 6.3555], dtype=np.float16) | ||
result1, offset2bag, bag_size, max_indices = aten_embedding_bag(weight, indices, offsets, mode=mode, per_sample_weights=per_sample_weights) |
Check warning
Code scanning / lintrunner
PYLINT/W0612 Warning
See unused-variable. To disable, use # pylint: disable=unused-variable
# include_last_offset = True | ||
per_sample_weights = np.array([2.4134, -0.1783, 7.1360, -0.7987, 2.3815], dtype=np.float16) | ||
#per_sample_weights = np.array([-2.2930, 6.2148, 3.1562, 0.0791, 6.3555], dtype=np.float16) | ||
result1, offset2bag, bag_size, max_indices = aten_embedding_bag(weight, indices, offsets, mode=mode, per_sample_weights=per_sample_weights) |
Check warning
Code scanning / lintrunner
PYLINT/W0612 Warning
See unused-variable. To disable, use # pylint: disable=unused-variable
# include_last_offset = True | ||
per_sample_weights = np.array([2.4134, -0.1783, 7.1360, -0.7987, 2.3815], dtype=np.float16) | ||
#per_sample_weights = np.array([-2.2930, 6.2148, 3.1562, 0.0791, 6.3555], dtype=np.float16) | ||
result1, offset2bag, bag_size, max_indices = aten_embedding_bag(weight, indices, offsets, mode=mode, per_sample_weights=per_sample_weights) |
Check warning
Code scanning / lintrunner
PYLINT/W0612 Warning
See unused-variable. To disable, use # pylint: disable=unused-variable
#test_embedding_bag_aten() | ||
#test_embedding_bag_nn_function() | ||
|
||
exit(0) |
Check notice
Code scanning / lintrunner
PYLINT/R1722 Note
See consider-using-sys-exit. To disable, use # pylint: disable=consider-using-sys-exit
from #1056 this, copy sample=7 data
we can got same result from test_embedding_bag_onnx() and test_embedding_bag_aten().
result from onnx-script:
result from aten: