Skip to content

Commit

Permalink
[XLA:GPU] Turn --xla_gpu_triton_gemm_any on by default.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 725234232
  • Loading branch information
bchetioui authored and Google-ML-Automation committed Feb 10, 2025
1 parent cd8dc73 commit b39e89e
Show file tree
Hide file tree
Showing 2 changed files with 1 addition and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ limitations under the License.
#include <vector>

#include <gtest/gtest.h>
#include "absl/log/log.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_replace.h"
#include "absl/strings/str_split.h"
Expand Down Expand Up @@ -720,7 +719,6 @@ INSTANTIATE_TEST_SUITE_P(ParametrizedTritonTest, ParametrizedTritonTest,
I4TestParams::ToString);

TEST_F(TritonTest, NonstandardLayoutWithManyNonContractingDims) {
// We cannot do triton_gemm and we use cuBLAS instead.
constexpr absl::string_view kHloText = R"(
HloModule NonstandardLayoutWithManyNonContractingDims
Expand All @@ -734,7 +732,6 @@ TEST_F(TritonTest, NonstandardLayoutWithManyNonContractingDims) {
)";

TF_ASSERT_OK_AND_ASSIGN(auto module, GetOptimizedModule(kHloText));
EXPECT_TRUE(*RunFileCheck(module->ToString(), R"(CHECK: "__cublas$gemm")"));
EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-2}));
}

Expand All @@ -753,7 +750,6 @@ TEST_F(TritonTest, NonstandardLayoutWithManyNonContractingDimsReversedLayout) {
)";

TF_ASSERT_OK_AND_ASSIGN(auto module, GetOptimizedModule(kHloText));
EXPECT_TRUE(*RunFileCheck(module->ToString(), R"(CHECK: "__cublas$gemm")"));
EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3}));
}

Expand Down
2 changes: 1 addition & 1 deletion xla/debug_options_flags.cc
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() {
opts.set_xla_gpu_enable_triton_gemm(true);
opts.set_xla_gpu_unsupported_enable_generic_triton_emitter_for_gemms(false);
opts.set_xla_gpu_enable_cudnn_int8x32_convolution_reordering(true);
opts.set_xla_gpu_triton_gemm_any(false);
opts.set_xla_gpu_triton_gemm_any(true);
opts.set_xla_gpu_verify_triton_fusion_numerics(false);

// Moving reduce-scatter out of while loops can increase memory footprint, so
Expand Down

0 comments on commit b39e89e

Please sign in to comment.