[XLA:GPU] Add a debug option xla_gpu_unsupported_force_triton_gemm for use in tests.

This works around test-parametrization issues while `xla_gpu_triton_gemm_any` itself
needs to be worked around (gated to Hopper+) in the main compiler path for A100.

PiperOrigin-RevId: 725949936
bchetioui authored and Google-ML-Automation committed Feb 12, 2025
1 parent 4dbd9af commit c68d132
Showing 9 changed files with 24 additions and 13 deletions.
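The change applied across the test files below is uniform: each fixture that previously set `xla_gpu_triton_gemm_any` now sets the new force flag instead, so tests keep getting Triton GEMM fusions on A100 despite the architecture gating in the main compiler path. A minimal sketch of the test-side pattern, with a hypothetical fixture name (the real overrides are in the hunks that follow):

// Sketch only: `MyTritonGemmTest` is hypothetical; the real fixtures
// (AlgorithmTest, TritonGemmTestAny, TritonTest, ...) appear in the
// changed files below.
class MyTritonGemmTest : public GpuCodegenTest {
 public:
  DebugOptions GetDebugOptionsForTest() const override {
    DebugOptions debug_options = GpuCodegenTest::GetDebugOptionsForTest();
    // Force Triton GEMM fusions regardless of the per-architecture
    // gating that xla_gpu_triton_gemm_any is subject to on A100.
    debug_options.set_xla_gpu_unsupported_force_triton_gemm(true);
    // Do not fall back to cuBLAS; these tests exercise Triton itself.
    debug_options.set_xla_gpu_cublas_fallback(false);
    return debug_options;
  }
};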
8 changes: 4 additions & 4 deletions xla/backends/gpu/codegen/triton/dot_algorithms_test.cc
@@ -71,7 +71,7 @@ class AlgorithmTest : public GpuCodegenTest {
debug_options.set_xla_gpu_dump_autotuned_gemm_fusions(true);

// Enable triton fusion for all supported GEMMs.
- debug_options.set_xla_gpu_triton_gemm_any(true);
+ debug_options.set_xla_gpu_unsupported_force_triton_gemm(true);

return debug_options;
}
@@ -156,7 +156,7 @@ class TritonAlgorithmTest : public AlgorithmTest {
// Do not fall back to cuBLAS, we are testing Triton.
debug_options.set_xla_gpu_cublas_fallback(false);
// Enable gemm for any hlo including pure matmuls.
- debug_options.set_xla_gpu_triton_gemm_any(true);
+ debug_options.set_xla_gpu_unsupported_force_triton_gemm(true);
// Do not autotune split-k by default, since this prevents deterministically
// matching the optimized HLO.
debug_options.set_xla_gpu_enable_split_k_autotuning(false);
@@ -564,7 +564,7 @@ class Triton3xBF16GemmTest : public AlgorithmTest {
// to be on the safe side against future flakiness.
//
// Enable triton fusion for all supported GEMMs.
- debug_options.set_xla_gpu_triton_gemm_any(true);
+ debug_options.set_xla_gpu_unsupported_force_triton_gemm(true);
// Do not fall back to cuBLAS, we are testing Triton.
debug_options.set_xla_gpu_cublas_fallback(false);

@@ -1318,7 +1318,7 @@ class TritonAndBlasSupportForDifferentTensorSizes
debug_options_ = GetDebugOptionsForTest();

triton_options_ = debug_options_;
- triton_options_.set_xla_gpu_triton_gemm_any(true);
+ triton_options_.set_xla_gpu_unsupported_force_triton_gemm(true);
triton_options_.set_xla_gpu_cublas_fallback(false);

blas_options_ = debug_options_;
@@ -1507,7 +1507,7 @@ class TritonGemmTestAny : public TritonGemmTest {
public:
DebugOptions GetDebugOptionsForTest() const override {
DebugOptions debug_options = TritonGemmTest::GetDebugOptionsForTest();
- debug_options.set_xla_gpu_triton_gemm_any(true);
+ debug_options.set_xla_gpu_unsupported_force_triton_gemm(true);
return debug_options;
}
};
@@ -4089,7 +4089,7 @@ class TritonGemmContractionDims : public TritonGemmTest {
DebugOptions GetDebugOptionsForTest() const override {
DebugOptions debug_options = TritonGemmTest::GetDebugOptionsForTest();
debug_options.set_xla_gpu_ensure_minor_dot_contraction_dims(true);
- debug_options.set_xla_gpu_triton_gemm_any(true);
+ debug_options.set_xla_gpu_unsupported_force_triton_gemm(true);

return debug_options;
}
@@ -151,7 +151,7 @@ class TritonTest : public GpuCodegenTest {
public:
DebugOptions GetDebugOptionsForTest() const override {
DebugOptions debug_options = GpuCodegenTest::GetDebugOptionsForTest();
- debug_options.set_xla_gpu_triton_gemm_any(true);
+ debug_options.set_xla_gpu_unsupported_force_triton_gemm(true);
debug_options.set_xla_gpu_cublas_fallback(false);
// Always rewrite Gemms with Triton regardless of size.
debug_options.set_xla_gpu_gemm_rewrite_size_threshold(0);
1 change: 1 addition & 0 deletions xla/debug_options_flags.cc
@@ -210,6 +210,7 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() {
opts.set_xla_gpu_unsupported_enable_generic_triton_emitter_for_gemms(false);
opts.set_xla_gpu_enable_cudnn_int8x32_convolution_reordering(true);
opts.set_xla_gpu_triton_gemm_any(true);
+ opts.set_xla_gpu_unsupported_force_triton_gemm(false);
opts.set_xla_gpu_verify_triton_fusion_numerics(false);

// Moving reduce-scatter out of while loops can increase memory footprint, so
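The new option defaults to false, so compiler behavior is unchanged unless a caller opts in. Since it is a DebugOptions proto field, the protobuf-generated accessors apply; a sketch of reading it where an HloInstruction is in scope, mirroring the access pattern in gemm_fusion.cc below (the helper is hypothetical, not part of this commit):

// Hypothetical helper: checks whether Triton GEMM fusion is forced
// for the module that owns `instr`.
bool ForceTritonGemm(const HloInstruction& instr) {
  const DebugOptions& debug_options =
      instr.GetModule()->config().debug_options();
  return debug_options.xla_gpu_unsupported_force_triton_gemm();
}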
2 changes: 1 addition & 1 deletion xla/pjrt/gpu/se_gpu_pjrt_client_test.cc
@@ -2065,7 +2065,7 @@ absl::Status ShardedAutotuningWorksTestBody(const int node_id,
DebugOptions& debug_options =
*compile_options.executable_build_options.mutable_debug_options();
debug_options.set_xla_gpu_shard_autotuning(true);
- debug_options.set_xla_gpu_triton_gemm_any(true);
+ debug_options.set_xla_gpu_unsupported_force_triton_gemm(true);
debug_options.set_xla_gpu_cublas_fallback(false);

if (node_id < num_nodes_using_cache) {
2 changes: 1 addition & 1 deletion xla/service/gpu/float_support_test.cc
@@ -51,7 +51,7 @@ class FloatSupportTestWithTriton : public FloatSupportTest {
DebugOptions GetDebugOptionsForTest() const override {
DebugOptions debug_options = FloatSupportTest::GetDebugOptionsForTest();
debug_options.set_xla_gpu_enable_triton_gemm(true);
- debug_options.set_xla_gpu_triton_gemm_any(true);
+ debug_options.set_xla_gpu_unsupported_force_triton_gemm(true);
debug_options.set_xla_gpu_cublas_fallback(false);
return debug_options;
}
2 changes: 1 addition & 1 deletion xla/service/gpu/tests/tensor_float_32_global_var_test.cc
@@ -50,7 +50,7 @@ class TensorFloat32GlobalVarTest : public ::testing::WithParamInterface<bool>,
const bool enable_triton_gemm = GetParam();
if (enable_triton_gemm) {
debug_options.set_xla_gpu_enable_triton_gemm(true);
- debug_options.set_xla_gpu_triton_gemm_any(true);
+ debug_options.set_xla_gpu_unsupported_force_triton_gemm(true);
debug_options.set_xla_gpu_cublas_fallback(false);
} else {
debug_options.set_xla_gpu_enable_triton_gemm(false);
8 changes: 6 additions & 2 deletions xla/service/gpu/transforms/gemm_fusion.cc
@@ -737,15 +737,19 @@ absl::StatusOr<Decision> CreateDotFusion(
}
}

- bool should_use_triton_gemm_any =
-     dot.GetModule()->config().debug_options().xla_gpu_triton_gemm_any();
+ const DebugOptions& debug_options = dot.GetModule()->config().debug_options();
+ bool should_use_triton_gemm_any = debug_options.xla_gpu_triton_gemm_any();

// TODO(b/395903738): Remove this once F16 -> F8E5M2 conversion is fixed.
if (auto* cc = std::get_if<se::CudaComputeCapability>(&gpu_version)) {
should_use_triton_gemm_any =
should_use_triton_gemm_any && cc->IsAtLeastHopper();
}

+ should_use_triton_gemm_any =
+     should_use_triton_gemm_any ||
+     debug_options.xla_gpu_unsupported_force_triton_gemm();

const PrecisionConfig::Algorithm algorithm =
dot.precision_config().algorithm();
if (algorithm == PrecisionConfig::ALG_DOT_BF16_BF16_F32_X6 ||
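Net effect of the hunk above: the "any" behavior stays suppressed on pre-Hopper CUDA devices (the A100 workaround), but the new flag overrides that gating. A standalone restatement of the resulting predicate, as a sketch (the helper function is hypothetical; the types follow the surrounding file):

// Hypothetical restatement of the decision computed inline in
// CreateDotFusion after this change; not a function in the codebase.
bool ShouldUseTritonGemmAny(const DebugOptions& debug_options,
                            const se::GpuComputeCapability& gpu_version) {
  bool any = debug_options.xla_gpu_triton_gemm_any();
  // TODO(b/395903738): suppressed on pre-Hopper CUDA devices until the
  // F16 -> F8E5M2 conversion issue is fixed.
  if (auto* cc = std::get_if<se::CudaComputeCapability>(&gpu_version)) {
    any = any && cc->IsAtLeastHopper();
  }
  // The force flag wins over the architecture gating.
  return any || debug_options.xla_gpu_unsupported_force_triton_gemm();
}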
8 changes: 7 additions & 1 deletion xla/xla.proto
@@ -756,6 +756,12 @@ message DebugOptions {
// Internal debug/testing flag to switch Triton GEMM fusions on or off.
bool xla_gpu_unsupported_enable_triton_gemm = 322;

+ // Internal debug/testing flag to force all GEMMs to use Triton, independently
+ // of known issues.
+ // TODO(b/395903738): use to make specific tests pass on A100 while working
+ // around this bug. This can be removed once the bug is fixed.
+ bool xla_gpu_unsupported_force_triton_gemm = 369;

// This instructs the runtime whether to use memcpy for p2p communication when
// source and target are located within a node(nvlink).
bool xla_gpu_use_memcpy_local_p2p = 287;
@@ -1152,7 +1158,7 @@ message DebugOptions {

// Note: when adding a new flag, please add it to one of the hardware-specific
// or hardware-agnostic sections at the top of this proto message.
- // Next id: 369
+ // Next id: 370

// Extra options to pass to the compilation backend (e.g. LLVM); specific
// interpretation of these values is left to the backend.
