Add a few related optimization passes for fp8 gemm custom-calls. #16975
Conversation
Is it possible to add tests? Do we have an explanation of why those passes are required for correctness?
Let me provide more context: we are migrating from our original fake-quantization-like FP8 pattern to a new direct-quantization approach, where the DQ scaling is applied as the epilogue of the dot operation. This change lets us stop worrying about other XLA optimizer passes breaking our patterns, since in direct quantization the dot can handle FP8 inputs directly and will always be lowered to an fp8 gemm custom call. During the migration, @elfiegg found that when Triton GEMM falls back to the GEMM rewriter for cuBLAS, these three specific passes, i.e. LayoutNormalization, GpuAlgebraicSimplifier, and ScatterSimplifier, are necessary to ensure correctness. I think she is still investigating the specific reason and working on the unit test.
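For illustration, here is a rough sketch of what adding these three passes looks like on the GPU pipeline side. The pipeline name, placement, and constructor arguments are assumptions made for this sketch, not the exact change in this PR:

```cpp
// Illustrative sketch only: roughly what "adding these passes" amounts to.
// Pass placement and constructor arguments are assumptions; see the actual
// diff in this PR for the real wiring.
absl::Status RunFp8FallbackCleanup(HloModule* module,
                                   const se::GpuComputeCapability& gpu_version) {
  HloPassPipeline pipeline("fp8-gemm-fallback-cleanup");
  // Normalize non-default operand layouts before GemmRewriter runs again on
  // fusions the autotuner sent back to cuBLAS.
  pipeline.AddPass<LayoutNormalization>(&NormalizeLayoutForGpuCustomCalls);
  // Clean up transposes/bitcasts introduced by the normalization.
  pipeline.AddPass<GpuAlgebraicSimplifier>(AlgebraicSimplifierOptions(),
                                           gpu_version);
  // Keep scatters in the canonical form later passes expect.
  pipeline.AddPass<ScatterSimplifier>();
  return pipeline.Run(module).status();
}
```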
OK, thanks for the context!
Thanks Kaixi for getting everyone on the same page, and sorry for the delay. As mentioned, during debugging we found that layout normalization is crucial for ensuring numerical correctness: we consistently reproduced the numerical issue when using different operand layout permutations with cuBLAS. I've also added a unit test to ensure numerical correctness in the pipeline, both with Triton fusion falling back to cuBLAS and without Triton fusion (the test fails without the changes in this PR).
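To give a concrete picture of that comparison, here is a minimal sketch of such a test, assuming the GpuCompilerTest/HloTestBase fixture used in gpu_compiler_test.cc; the HLO text, flag choices, and tolerances are placeholders rather than the PR's actual test:

```cpp
// Minimal sketch (not the PR's exact test): compile and run the same HLO with
// Triton GEMM enabled (so the autotuner may fall back to cuBLAS) and with
// Triton disabled, then require numerically matching results. The HLO below is
// a placeholder fp8-flavored GEMM; the tolerances are illustrative.
TEST_F(GpuCompilerTest, TritonFallbackMatchesNonTritonSketch) {
  constexpr absl::string_view kHloText = R"(
    HloModule m
    ENTRY e {
      p0 = f8e4m3fn[64,32] parameter(0)
      p1 = f8e4m3fn[32,64] parameter(1)
      c0 = bf16[64,32] convert(p0)
      c1 = bf16[32,64] convert(p1)
      ROOT d = bf16[64,64] dot(c0, c1),
        lhs_contracting_dims={1}, rhs_contracting_dims={0}
    })";

  HloModuleConfig triton_on_config;
  DebugOptions triton_on = GetDebugOptionsForTest();
  triton_on.set_xla_gpu_enable_triton_gemm(true);
  triton_on_config.set_debug_options(triton_on);

  HloModuleConfig triton_off_config;
  DebugOptions triton_off = GetDebugOptionsForTest();
  triton_off.set_xla_gpu_enable_triton_gemm(false);
  triton_off_config.set_debug_options(triton_off);

  TF_ASSERT_OK_AND_ASSIGN(
      auto module_triton,
      ParseAndReturnVerifiedModule(kHloText, triton_on_config));
  TF_ASSERT_OK_AND_ASSIGN(
      auto module_no_triton,
      ParseAndReturnVerifiedModule(kHloText, triton_off_config));
  EXPECT_TRUE(RunAndCompareTwoModules(std::move(module_triton),
                                      std::move(module_no_triton),
                                      ErrorSpec{/*aabs=*/1e-2, /*arel=*/1e-2},
                                      /*run_hlo_passes=*/true));
}
```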
Force-pushed from 0662649 to ac06ecf
Removed a duplicate GpuAlgebraicSimplifier pass that is already in place at a later stage. Upon further investigation, it seems the cuBLAS runtime thunk processes the layout correctly: the final MatrixLayout of both operands and the buffer assignments are identical for the cublasLt custom calls in both cases, f8e4m3fn[12288,4096]{1,0} and f8e4m3fn[4096,12288]{0,1}. Despite this, cublasLt produces different numerical results for the two calls, and the root cause is still unclear. Since LayoutNormalization has been used for previous cublasLt FP8 GEMM calls without Triton fp8 gemm, I believe it's safe to proceed with adding this pass for now, pending further investigation.
Further investigation shows that transpose appears to be broken for fp8 operands, as the following modules generate different numerics; I'll file another bug to track this.
I'm not sure why this is necessary. It seems the core issue is that certain layouts have incorrect numerics. That said, I'm ok taking this for now if it does fix the layout issue, but in the long term it's better to directly ensure different layouts still have correct numerics.
Does LayoutNormalization even affect a cublas gemm custom call? I see we pass NormalizeLayoutForGpuCustomCalls, but that only affects convolutions, not gemms.
xla/service/gpu/gpu_compiler_test.cc (outdated)
HloModuleConfig config;
DebugOptions triton_enabled_debug_options = GetDebugOptionsForTest();
triton_enabled_debug_options.set_xla_gpu_enable_dynamic_slice_fusion(false);
Why disable dynamic slice fusion?
No shame in copy-pasting from the other test :^)
xla/service/gpu/gpu_compiler_test.cc (outdated)
config.set_replica_count(1);
config.set_num_partitions(1);
No need to set these, as they are the defaults
Done.
xla/service/gpu/gpu_compiler_test.cc (outdated)
// Load autotuning DB. We shouldn't depend on actual execution times in a unit
// test.
std::string path =
    tsl::io::JoinPath(tsl::testing::XlaSrcRoot(), "service", "gpu",
                      "gpu_compiler_test_autotune_db.textproto");
I would not load autotune results, but instead disable the cublas fallback for the triton case by calling triton_enabled_debug_options.set_xla_gpu_cublas_fallback(false). That way you don't have to make sure gpu_compiler_test_autotune_db.textproto has the exact gemm config that this HLO generates. Maybe disable autotuning as well if you want the outputs to be deterministic.
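Concretely, the suggestion amounts to something like the sketch below. The cuBLAS fallback flag is the one named above; disabling autotuning via xla_gpu_autotune_level is an assumption made here for determinism, not part of this PR:

```cpp
// Sketch of the reviewer's suggestion, not the final test code.
DebugOptions triton_enabled_debug_options = GetDebugOptionsForTest();
// Keep Triton-rewritten gemms on Triton instead of letting the autotuner fall
// back to cuBLAS, so the test no longer depends on the autotune DB.
triton_enabled_debug_options.set_xla_gpu_cublas_fallback(false);
// Optionally disable autotuning entirely for deterministic outputs
// (assumed knob; use whatever the test fixture already provides for this).
triton_enabled_debug_options.set_xla_gpu_autotune_level(0);
```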
Ah, let me clarify: the test compares the Triton-enabled path falling back to cublasLt against the Triton-disabled path. Do you have any suggestions or concerns about that?
Renamed the test to be clearer.
That makes sense, and the test name clarifies things. I forgot that GemmRewriter is called twice: once to handle gemms that the Triton rewriter didn't handle, and again to handle formerly-Triton fusions that the autotuner decided to use cublas for. So it's good to test that the two cases are numerically equivalent.
@elfie just filed an issue to narrow this down further. Here's the current understanding: we've learned that the GEMM rewriter inserts a transpose for one operand as follows:

However, directly running this results in incorrect outputs. Upon investigation, we found that the layout normalization pass inserts a bitcast so that the following pattern works:

This modification produces the correct results. Given this, we are curious whether there is a usage restriction that the transpose must operate over a certain layout.
Force-pushed from b85df42 to 90f5968
Thanks again Kaixi for helping bridge the communication gaps.
Regarding this, I had the same confusion. It seems the pipeline may still use legacy naming for historical reasons, which we should consider updating. However, as mentioned, the pass is indeed normalizing the non-default layouts of transform instructions to the default layout, and we have observed numerical differences in the execution results after this normalization.
Add a few related optimization passes for fp8 gemm custom-calls.

Imported from GitHub PR #16975

This fixes a convergence issue for fp8 training, tested on GPT3 models:

Before:
```
NETWORK  BACKEND  MATH  SDPA  XLA_EXTRAS  GPUs  STEPS/SEC  LOSS    WALLSECS
GPT5B    XLA      fp8   FA                8     1.064      11.019  1571
[PAX STATUS]: Starting training loop.
[PAX STATUS] step_i: 100, training loss: 11.015041
[PAX STATUS] step_i: 200, training loss: 11.016165
[PAX STATUS] step_i: 300, training loss: 11.016386
[PAX STATUS] step_i: 400, training loss: 11.014653
[PAX STATUS] step_i: 500, training loss: 11.014734
[PAX STATUS] step_i: 600, training loss: 11.01613
[PAX STATUS] step_i: 700, training loss: 11.009399
[PAX STATUS] step_i: 800, training loss: 11.017071
[PAX STATUS] step_i: 900, training loss: 11.014582
[PAX STATUS] step_i: 1000, training loss: 11.013434
[PAX STATUS] step_i: 1100, training loss: 11.021271
[PAX STATUS] step_i: 1200, training loss: 11.008364
[PAX STATUS] step_i: 1300, training loss: 11.0198145
[PAX STATUS] step_i: 1400, training loss: 11.01253
[PAX STATUS] step_i: 1500, training loss: 11.019016
```

After:
```
NETWORK  BACKEND  MATH  SDPA  GPUs  STEPS/SEC  LOSS   WALLSECS
GPT5B    XLA      fp8   FA    8     1.020      3.797  1647
[PAX STATUS]: Starting training loop.
[PAX STATUS] step_i: 100, training loss: 6.150083
[PAX STATUS] step_i: 200, training loss: 5.8871064
[PAX STATUS] step_i: 300, training loss: 5.4491887
[PAX STATUS] step_i: 400, training loss: 5.6384015
[PAX STATUS] step_i: 500, training loss: 5.273538
[PAX STATUS] step_i: 600, training loss: 5.2011905
[PAX STATUS] step_i: 700, training loss: 4.903013
[PAX STATUS] step_i: 800, training loss: 4.62972
[PAX STATUS] step_i: 900, training loss: 4.507727
[PAX STATUS] step_i: 1000, training loss: 4.625259
[PAX STATUS] step_i: 1100, training loss: 4.428066
[PAX STATUS] step_i: 1200, training loss: 4.252451
[PAX STATUS] step_i: 1300, training loss: 3.8448389
[PAX STATUS] step_i: 1400, training loss: 3.8578327
[PAX STATUS] step_i: 1500, training loss: 3.796958
```

Copybara import of the project:

-- 90f5968 by Elfie Guo <[email protected]>:

Add a few related optimization passes for fp8 gemm rewriter.

Merging this change closes #16975

FUTURE_COPYBARA_INTEGRATE_REVIEW=#16975 from elfiegg:pass 90f5968
PiperOrigin-RevId: 675755585
This PR was rolled back in fd64718!
Currently investigating whether I can provide a fix and roll forward.
Add a few related optimization passes for fp8 gemm custom-calls. Reverts fd64718. PiperOrigin-RevId: 686037932
Created a fix in #18342.