Transposing to different layout permutations results in different numerics #17276
Unit-test reproducer (which we also modified to test a transpose as the root instruction of the module):

@akuegel Is my understanding right that transposes should always use the default layout, and that this is normally ensured by layout normalization? If so, should we try to detect violations at codegen time or in the HLO verifier?
It's indeed expected that cuBLAS GEMMs get their layouts normalized, and if you feed this HLO to the optimization passes, it already fails (with a slightly different error):
We can add a check somewhere to ensure that layouts reaching the custom call are normalized, but that only guards an internal invariant: it is (at least in theory) impossible to reach this state from pre-optimization HLO.
It's not about cuBLAS; it's about transpose alone. See the second reproducer in the first message.
@sergachev While Layout Normalization will make sure that transposes have the default layout, passes later in the pipeline could still create transposes with a non-default layout. Note that anything that calls MakeTransposeHlo from hlo_creation_utils will most likely produce a non-default layout, as that function infers a layout that makes the transpose a bitcast. This is something we want to avoid, so if you see any pass that runs after LayoutNormalization and calls MakeTransposeHlo, please file a bug or send a PR.
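The kind of layout that MakeTransposeHlo infers can be reproduced with a small model. The sketch below is a hypothetical Python helper, not XLA's actual code: it represents a layout as a minor-to-major tuple (as in HLO text, e.g. `{1,0}`) and picks the result layout under which a transpose with `dimensions` leaves the physical element order unchanged, i.e. acts as a bitcast. For a default-layout operand this generally yields a non-default result layout, which is the situation described above.

```python
from typing import Sequence, Tuple


def default_layout(rank: int) -> Tuple[int, ...]:
    """Default (descending) minor-to-major order, e.g. (1, 0) for rank 2."""
    return tuple(range(rank - 1, -1, -1))


def bitcast_transpose_layout(operand_m2m: Sequence[int],
                             dimensions: Sequence[int]) -> Tuple[int, ...]:
    """Result layout that makes transpose(operand, dimensions) a bitcast.

    Result dim j reads operand dim dimensions[j]. To keep the physical order,
    the i-th most minor result dim must be the one that reads operand_m2m[i].
    """
    inverse = [0] * len(dimensions)
    for out_dim, in_dim in enumerate(dimensions):
        inverse[in_dim] = out_dim
    return tuple(inverse[d] for d in operand_m2m)
```

For example, `bitcast_transpose_layout((1, 0), (1, 0))` gives `(0, 1)`: transposing a row-major `[x,y]{1,0}` operand "for free" produces `[y,x]{0,1}`, a bitcast whose layout is not the default `{1,0}`.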
@akuegel It sounds to me that, generally speaking, we should ensure layout normalization and its associated passes run after all rewriters and op-changing passes, right before codegen, so that every op is accounted for? The layouts normalized by the pass might be a strict requirement.
We already have HLO passes that rely on all transposes having the default layout. For example, the one I added recently (TransposeDimensionGrouper) only works on transposes with the default layout and returns an error otherwise. So just running layout normalization once more at the end of the pipeline will not fix the issue, and @sergachev's suggestion to make this part of the HloVerifier sounds better to me. It would need to be a verifier option that is off by default but can be turned on in our pipeline after the LayoutNormalization pass.
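A verifier-style check for this invariant might look like the toy sketch below. Everything in it is an assumption for illustration: the `Instr` record, its fields, and the `require_default_transpose_layout` flag are hypothetical stand-ins, not HloVerifier's real options or API. Mirroring the proposal, the check is off by default and, once enabled, reports every transpose whose result layout is not the descending default.

```python
from dataclasses import dataclass
from typing import List, Sequence, Tuple


@dataclass
class Instr:
    """Hypothetical, minimal stand-in for an HLO instruction."""
    name: str
    opcode: str
    minor_to_major: Tuple[int, ...]  # result layout, e.g. (1, 0)


def default_layout(rank: int) -> Tuple[int, ...]:
    return tuple(range(rank - 1, -1, -1))


def fmt(m2m: Sequence[int]) -> str:
    """Render a layout the way HLO text does, e.g. {1,0}."""
    return "{" + ",".join(str(d) for d in m2m) + "}"


def verify_transpose_layouts(
        instrs: List[Instr],
        require_default_transpose_layout: bool = False) -> List[str]:
    """Return error messages; an empty list means the check passed."""
    if not require_default_transpose_layout:
        return []  # Off by default, as proposed in the thread.
    errors = []
    for instr in instrs:
        if instr.opcode != "transpose":
            continue
        expected = default_layout(len(instr.minor_to_major))
        if tuple(instr.minor_to_major) != expected:
            errors.append(
                f"{instr.name}: transpose has non-default layout "
                f"{fmt(instr.minor_to_major)}, expected {fmt(expected)}")
    return errors
```

Keeping the flag off by default lets pipelines that run before LayoutNormalization, where non-default transposes are legitimate, keep verifying unchanged.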
OK, that sounds good! My original comment was about the other instructions that layout normalization handles, in a broader sense. Is the normalized layout a strict requirement for every instruction the pass standardizes, or is transpose a special case we happened to stumble upon that affects correctness?
Once LayoutNormalization has run, it is quite unlikely that other passes will introduce ops without the default layout. Normally the layout of a new op is derived from the ops surrounding it, so if all of those have the default layout, the new op will have it as well. Transpose is special because MakeTransposeHlo() chooses a non-default layout. I believe it was a mistake to make that function assign a non-default layout, but that would probably be quite hard to change now.
The layout of new ops is indeed derived from the surrounding ops, and the "bug" arises because some ops never get a chance to go through the LayoutNormalization pass. Triton first fuses FP8 GEMMs, but when layout normalization runs, the transpose has not yet been inserted by GemmRewriter, and ops inside these fusions are not handled either. When the autotuner later falls back to cuBLAS, the fused computations are inlined, and the cuBLAS GemmRewriter may insert a transpose with a non-default layout depending on the context. In this situation, would you consider it a bug that should be fixed by also running layout normalization after the computations are inlined, or should GemmRewriter avoid inserting a non-default-layout transpose in the first place?
Ideally we would insert a transpose with the default layout in GemmRewriter. A transpose that preserves the non-default layout of its operand can be normalized to the default layout by adding bitcast transposes before and after it. Unfortunately we still don't normalize Dots, so a dot's operand is often a bitcast with a non-default layout, and a transpose inserted between that bitcast and the dot would then have a non-default layout as well.
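The normalization described above is essentially a shape computation. The Python sketch below (hypothetical helper names; layouts as minor-to-major tuples; any rank) rewrites `transpose(operand) -> result` into a bitcast to a default-layout shape, a default-layout transpose, and a bitcast back to the original result shape and layout. It models the idea only and does not reflect how LayoutNormalization or GemmRewriter are actually implemented.

```python
from typing import Sequence, Tuple

Shape = Tuple[int, ...]


def physical_shape(shape: Sequence[int], m2m: Sequence[int]) -> Shape:
    """Dimension sizes listed in storage (major-to-minor) order."""
    return tuple(shape[d] for d in reversed(m2m))


def normalize_transpose(operand_shape: Sequence[int],
                        operand_m2m: Sequence[int],
                        dimensions: Sequence[int],
                        result_m2m: Sequence[int]):
    """Return (bitcast_shape, transpose_dims, transpose_shape).

    bitcast(operand)                 -> [bitcast_shape], default layout
    transpose(that, transpose_dims)  -> [transpose_shape], default layout
    bitcast(that)                    -> original result shape and layout
    """
    rank = len(operand_shape)
    result_shape = tuple(operand_shape[d] for d in dimensions)
    bitcast_shape = physical_shape(operand_shape, operand_m2m)
    transpose_shape = physical_shape(result_shape, result_m2m)
    # Position of each operand dim inside bitcast_shape (major-to-minor).
    phys_pos = {d: rank - 1 - i for i, d in enumerate(operand_m2m)}
    # Walk the result dims in storage order and find where each one's
    # source operand dim sits in the bitcast operand.
    transpose_dims = tuple(
        phys_pos[dimensions[out_dim]] for out_dim in reversed(result_m2m))
    return bitcast_shape, transpose_dims, transpose_shape
```

For a column-major operand `[4,8]{0,1}` transposed to `[8,4]{0,1}` with `dimensions=(1,0)`, `normalize_transpose((4, 8), (0, 1), (1, 0), (0, 1))` returns `((8, 4), (1, 0), (4, 8))`: bitcast to `[8,4]{1,0}`, transpose to `[4,8]{1,0}`, then bitcast back to `[8,4]{0,1}`. A transpose that already has default layouts maps to itself.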
… rewrite

Imported from GitHub PR #17440

Related to #17276 and #16975.

This PR updates the GemmRewriter to handle the transpose of non-descending layouts directly, eliminating the need for the layout_normalization pass to correct this error-prone pattern post-rewrite. The desired transformation is now injected into GemmRewriter, ensuring the problematic layout is handled internally.

This PR transforms the following error-prone pattern, where the transpose of a non-descending layout is the issue:

```
a = f8e4m3fn[x,y]{0,1} xxx
transpose.0 = f8e4m3fn[y,x]{0,1} transpose(a), dimensions=(1,0)
custom-call(transpose.0,...)
```

to

```
a = f8e4m3fn[x,y]{0,1} xxx
bt = f8e4m3fn[y,x]{1,0} bitcast(a)
transpose.1 = f8e4m3fn[x,y]{1,0} transpose(bt), dimensions=(1,0)
bt.1 = f8e4m3fn[y,x]{0,1} bitcast(transpose.1)
custom-call(bt.1,...)
```

Copybara import of the project:

-- 237c032 by shuw: Improve TransposeMatrix
-- 508cd69 by Shu Wang: Fix bug of permutation.
-- c55e8a9 by shuw: clang format
-- ad0a4ba by Shu Wang: Add unittest.
-- 1d45b4d by Shu Wang: Remove uncessary space.
-- 7837845 by Shu Wang: Update unittest.
-- b479c21 by shuw: Improve TransposeMatrix
-- b633184 by Shu Wang: Update unittest shape and BUILD file.

Merging this change closes #17440

FUTURE_COPYBARA_INTEGRATE_REVIEW=#17440 from wenscarl:fp8_regulate_transpose b633184
PiperOrigin-RevId: 680886834
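To see why the bitcast/transpose/bitcast rewrite preserves numerics, the short Python simulation below models 2D buffers explicitly: a layout is a minor-to-major tuple and `bitcast` reinterprets the same flat buffer under a new shape and layout. The helper names are hypothetical, the dtype is ignored, and x=2, y=3 are arbitrary sizes chosen for illustration. It checks that `bt.1` matches `transpose.0` both logically and in physical memory.

```python
from typing import List, Sequence, Tuple

Matrix = List[List[int]]


def to_buffer(m: Matrix, m2m: Tuple[int, int]) -> List[int]:
    """Flatten a logical 2D array in the physical order given by its layout."""
    rows, cols = len(m), len(m[0])
    if m2m == (1, 0):  # dim 1 is minor: row-major storage
        return [m[i][j] for i in range(rows) for j in range(cols)]
    return [m[i][j] for j in range(cols) for i in range(rows)]  # (0, 1)


def from_buffer(buf: Sequence[int], shape: Tuple[int, int],
                m2m: Tuple[int, int]) -> Matrix:
    """Read a flat buffer back as a logical 2D array with the given layout."""
    rows, cols = shape
    if m2m == (1, 0):
        return [[buf[i * cols + j] for j in range(cols)] for i in range(rows)]
    return [[buf[j * rows + i] for j in range(cols)] for i in range(rows)]


def transpose(m: Matrix) -> Matrix:
    """Logical transpose with dimensions=(1,0)."""
    return [list(r) for r in zip(*m)]


def bitcast(m: Matrix, m2m_in, shape_out, m2m_out) -> Matrix:
    """Same bytes, new shape and layout."""
    return from_buffer(to_buffer(m, m2m_in), shape_out, m2m_out)


x, y = 2, 3
a = [[1, 2, 3], [4, 5, 6]]                    # [x,y]{0,1}

# Original pattern: transpose.0 = [y,x]{0,1} transpose(a)
transpose_0 = transpose(a)

# Rewritten pattern
bt = bitcast(a, (0, 1), (y, x), (1, 0))       # [y,x]{1,0}
transpose_1 = transpose(bt)                   # [x,y]{1,0}
bt_1 = bitcast(transpose_1, (1, 0), (y, x), (0, 1))  # [y,x]{0,1}

print(bt_1 == transpose_0)  # True
```

Only the middle transpose moves data, and it does so between default (row-major) layouts, which is the form downstream passes such as TransposeDimensionGrouper expect.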
… rewrite Imported from GitHub PR openxla/xla#17440 Related to openxla/xla#17276 and openxla/xla#16975. This PR updates the GemmRewriter to handle the transpose of non-descending layouts directly, eliminating the need for the layout_normalization pass to correct this error-prone pattern post-rewrite. The desired transformation is now injected into GemmRewriter, ensuring the problematic layout is handled internally. This PR transforms the following error-prone pattern, where the transpose of a non-descending layout is the issue: ``` a = f8e4m3fn[x,y]{0,1} xxx transpose.0 = f8e4m3fn[y,x]{0,1} transpose(a), dimensions=(1,0) custom-call(a,...) ``` to ``` a = f8e4m3fn[x,y]{0,1} xxx bt = f8e4m3fn[y,x]{1,0} bitcast(a) transpose.1 = f8e4m3fn[x,y]{1,0} transpose(bt), dimensions=(1,0) bt.1= f8e4m3fn[y,x]{0,1} bitcast(transpose.1) custom-call(bt.1,...) ``` Copybara import of the project: -- 237c03240da3dce736d92c8273dc1f9d3be53af5 by shuw <[email protected]>: Improve TransposeMatrix -- 508cd6928bbc20c1d87818eed4ee6190c6c9f691 by Shu Wang <[email protected]>: Fix bug of permutation. -- c55e8a9f64c8dac69907ccebce3b8109ddeb2c48 by shuw <[email protected]>: clang format -- ad0a4ba8054092dd79608865a823c1d432f81b21 by Shu Wang <[email protected]>: Add unittest. -- 1d45b4d64347c64a9483fd26caf7d8598818b855 by Shu Wang <[email protected]>: Remove uncessary space. -- 78378455e70e439e71da078c3099732a14292d7d by Shu Wang <[email protected]>: Update unittest. -- b479c2177672a0010ffba1630efdaec5ca4cee26 by shuw <[email protected]>: Improve TransposeMatrix -- b63318487153a8668b9f95574b054b0129194c0c by Shu Wang <[email protected]>: Update unittest shape and BUILD file. Merging this change closes #17440 FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#17440 from wenscarl:fp8_regulate_transpose b63318487153a8668b9f95574b054b0129194c0c PiperOrigin-RevId: 680886834
… rewrite Imported from GitHub PR #17440 Related to #17276 and #16975. This PR updates the GemmRewriter to handle the transpose of non-descending layouts directly, eliminating the need for the layout_normalization pass to correct this error-prone pattern post-rewrite. The desired transformation is now injected into GemmRewriter, ensuring the problematic layout is handled internally. This PR transforms the following error-prone pattern, where the transpose of a non-descending layout is the issue: ``` a = f8e4m3fn[x,y]{0,1} xxx transpose.0 = f8e4m3fn[y,x]{0,1} transpose(a), dimensions=(1,0) custom-call(a,...) ``` to ``` a = f8e4m3fn[x,y]{0,1} xxx bt = f8e4m3fn[y,x]{1,0} bitcast(a) transpose.1 = f8e4m3fn[x,y]{1,0} transpose(bt), dimensions=(1,0) bt.1= f8e4m3fn[y,x]{0,1} bitcast(transpose.1) custom-call(bt.1,...) ``` Copybara import of the project: -- 237c032 by shuw <[email protected]>: Improve TransposeMatrix -- 508cd69 by Shu Wang <[email protected]>: Fix bug of permutation. -- c55e8a9 by shuw <[email protected]>: clang format -- ad0a4ba by Shu Wang <[email protected]>: Add unittest. -- 1d45b4d by Shu Wang <[email protected]>: Remove uncessary space. -- 7837845 by Shu Wang <[email protected]>: Update unittest. -- b479c21 by shuw <[email protected]>: Improve TransposeMatrix -- b633184 by Shu Wang <[email protected]>: Update unittest shape and BUILD file. Merging this change closes #17440 FUTURE_COPYBARA_INTEGRATE_REVIEW=#17440 from wenscarl:fp8_regulate_transpose b633184 PiperOrigin-RevId: 680886834
… rewrite Imported from GitHub PR #17440 Related to #17276 and #16975. This PR updates the GemmRewriter to handle the transpose of non-descending layouts directly, eliminating the need for the layout_normalization pass to correct this error-prone pattern post-rewrite. The desired transformation is now injected into GemmRewriter, ensuring the problematic layout is handled internally. This PR transforms the following error-prone pattern, where the transpose of a non-descending layout is the issue: ``` a = f8e4m3fn[x,y]{0,1} xxx transpose.0 = f8e4m3fn[y,x]{0,1} transpose(a), dimensions=(1,0) custom-call(a,...) ``` to ``` a = f8e4m3fn[x,y]{0,1} xxx bt = f8e4m3fn[y,x]{1,0} bitcast(a) transpose.1 = f8e4m3fn[x,y]{1,0} transpose(bt), dimensions=(1,0) bt.1= f8e4m3fn[y,x]{0,1} bitcast(transpose.1) custom-call(bt.1,...) ``` Copybara import of the project: -- 237c032 by shuw <[email protected]>: Improve TransposeMatrix -- 508cd69 by Shu Wang <[email protected]>: Fix bug of permutation. -- c55e8a9 by shuw <[email protected]>: clang format -- ad0a4ba by Shu Wang <[email protected]>: Add unittest. -- 1d45b4d by Shu Wang <[email protected]>: Remove uncessary space. -- 7837845 by Shu Wang <[email protected]>: Update unittest. -- b479c21 by shuw <[email protected]>: Improve TransposeMatrix -- b633184 by Shu Wang <[email protected]>: Update unittest shape and BUILD file. Merging this change closes #17440 FUTURE_COPYBARA_INTEGRATE_REVIEW=#17440 from wenscarl:fp8_regulate_transpose b633184 PiperOrigin-RevId: 680886834
… rewrite

Imported from GitHub PR #17440

Related to #17276 and #16975.

This PR updates the GemmRewriter to handle the transpose of non-descending layouts directly, eliminating the need for the layout_normalization pass to correct this error-prone pattern post-rewrite. The desired transformation is now injected into GemmRewriter, ensuring the problematic layout is handled internally.

This PR transforms the following error-prone pattern, where the transpose of a non-descending layout is the issue:
```
a = f8e4m3fn[x,y]{0,1} xxx
transpose.0 = f8e4m3fn[y,x]{0,1} transpose(a), dimensions=(1,0)
custom-call(a,...)
```
to
```
a = f8e4m3fn[x,y]{0,1} xxx
bt = f8e4m3fn[y,x]{1,0} bitcast(a)
transpose.1 = f8e4m3fn[x,y]{1,0} transpose(bt), dimensions=(1,0)
bt.1 = f8e4m3fn[y,x]{0,1} bitcast(transpose.1)
custom-call(bt.1,...)
```
Copybara import of the project:

-- 237c032 by shuw <[email protected]>: Improve TransposeMatrix
-- 508cd69 by Shu Wang <[email protected]>: Fix bug of permutation.
-- c55e8a9 by shuw <[email protected]>: clang format
-- ad0a4ba by Shu Wang <[email protected]>: Add unittest.
-- 1d45b4d by Shu Wang <[email protected]>: Remove uncessary space.
-- 7837845 by Shu Wang <[email protected]>: Update unittest.
-- b479c21 by shuw <[email protected]>: Improve TransposeMatrix
-- b633184 by Shu Wang <[email protected]>: Update unittest shape and BUILD file.

Merging this change closes #17440

COPYBARA_INTEGRATE_REVIEW=#17440 from wenscarl:fp8_regulate_transpose b633184
PiperOrigin-RevId: 681551009
Hello, we stumbled upon a numerical issue with the modules below while training FP8-quantized models.
This resulted in different numerics. When we checked the cuBLAS runtime thunk, it processed the logical layout correctly, and buffer assignment was exactly the same in both cases.
We then wrote a unit test to check transpose numerics, shown below:
The two results differed in 99% of the elements, with relative errors > 1e-2.
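The issue does not say exactly how the mismatch was measured. A common way to compute a figure like this is sketched below in plain Python; `mismatch_fraction`, the relative-error definition `|actual - expected| / max(|expected|, eps)` and the `eps` floor are assumptions for illustration, not the reporters' test code.

```python
def mismatch_fraction(actual, expected, rtol=1e-2, eps=1e-12):
    """Fraction of elements whose relative error exceeds rtol.

    Relative error is |actual - expected| / max(|expected|, eps); the eps
    floor avoids dividing by zero when a reference value is exactly 0.
    """
    if len(actual) != len(expected):
        raise ValueError("arrays must have the same number of elements")
    if not expected:
        return 0.0
    bad = sum(
        1
        for a, e in zip(actual, expected)
        if abs(a - e) / max(abs(e), eps) > rtol
    )
    return bad / len(expected)


# Two of the four elements exceed a 1% relative error.
print(mismatch_fraction([1.0, 2.0, 3.5, 4.0], [1.0, 2.0, 3.0, 5.0]))  # 0.5
```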
Could you please help us understand why transposing to a different layout permutation would produce different numerics? Is transposing to a non-default layout a known issue, or are we making an unintentional assumption or mistake?
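One way to frame the question: a transpose is defined on logical indices, so the output layout should only change where the bytes land, never the values a consumer reads back. The minimal pure-Python sketch below (not XLA code; the helper names are hypothetical) models XLA minor-to-major layouts as strides and shows that transposing into `{1,0}` and `{0,1}` gives identical logical results but different buffers, and that any step which silently assumes the default `{1,0}` layout for a `{0,1}` buffer reads different numbers. This illustrates one possible failure mode only; it is not a diagnosis of the actual bug.

```python
from itertools import product


def strides(dims, minor_to_major):
    """XLA-style layout: minor_to_major[0] is the stride-1 dimension."""
    s, step = [0] * len(dims), 1
    for d in minor_to_major:
        s[d] = step
        step *= dims[d]
    return s


def offset(idx, st):
    return sum(i * k for i, k in zip(idx, st))


def transpose_into(logical, dims, perm, out_minor_to_major):
    """Logical transpose (output dim i = operand dim perm[i]) written into a
    flat buffer that uses the requested output layout."""
    out_dims = tuple(dims[p] for p in perm)
    st = strides(out_dims, out_minor_to_major)
    buf = [None] * len(logical)
    for idx, v in logical.items():
        buf[offset(tuple(idx[p] for p in perm), st)] = v
    return out_dims, buf


def read_logical(buf, dims, minor_to_major):
    """Interpret a flat buffer as a logical array with the given layout."""
    st = strides(dims, minor_to_major)
    return {idx: buf[offset(idx, st)] for idx in product(*map(range, dims))}


a = {(i, j): 10 * i + j for i in range(2) for j in range(3)}  # logical [2,3]

out_dims, buf_default = transpose_into(a, (2, 3), (1, 0), [1, 0])  # [3,2]{1,0}
_, buf_other = transpose_into(a, (2, 3), (1, 0), [0, 1])            # [3,2]{0,1}

# Same logical result, different bytes in memory.
same_logical = read_logical(buf_default, out_dims, [1, 0]) == read_logical(buf_other, out_dims, [0, 1])
# A reader that assumes {1,0} while the buffer is actually {0,1} gets wrong values.
misread = read_logical(buf_other, out_dims, [1, 0]) != read_logical(buf_default, out_dims, [1, 0])

print(same_logical, buf_default != buf_other, misread)  # True True True
```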