[XLA:GPU] Disable --xla_gpu_triton_gemm_any on Ampere.
Triton's conversion logic from `f16` to `f8e5m2` is wrong pre-Hopper.
Disabling this wholesale is a bit overkill, but easiest---since this flag
flip is what surfaced the issue in the first place.

PiperOrigin-RevId: 725731041
bchetioui authored and Google-ML-Automation committed Feb 11, 2025
1 parent fe39e23 commit 1328fae
Showing 2 changed files with 12 additions and 2 deletions.
1 change: 1 addition & 0 deletions xla/service/gpu/transforms/BUILD
@@ -1702,6 +1702,7 @@ cc_library(
         "//xla/service/gpu:triton_fusion_analysis",
         "//xla/service/gpu:triton_tiling_propagation",
         "//xla/stream_executor:device_description",
+        "//xla/stream_executor/cuda:cuda_compute_capability",
         "@com_google_absl//absl/algorithm:container",
         "@com_google_absl//absl/container:flat_hash_map",
         "@com_google_absl//absl/container:flat_hash_set",
13 changes: 11 additions & 2 deletions xla/service/gpu/transforms/gemm_fusion.cc
@@ -51,6 +51,7 @@ limitations under the License.
 #include "xla/service/gpu/triton_tiling_propagation.h"
 #include "xla/service/instruction_fusion.h"
 #include "xla/shape_util.h"
+#include "xla/stream_executor/cuda/cuda_compute_capability.h"
 #include "xla/stream_executor/device_description.h"
 #include "xla/util.h"
 #include "xla/xla_data.pb.h"
@@ -736,6 +737,15 @@ absl::StatusOr<Decision> CreateDotFusion(
     }
   }
 
+  bool should_use_triton_gemm_any =
+      dot.GetModule()->config().debug_options().xla_gpu_triton_gemm_any();
+
+  // TODO(b/395903738): Remove this once F16 -> F8E5M2 conversion is fixed.
+  if (auto* cc = std::get_if<se::CudaComputeCapability>(&gpu_version)) {
+    should_use_triton_gemm_any =
+        should_use_triton_gemm_any && cc->IsAtLeastHopper();
+  }
+
   const PrecisionConfig::Algorithm algorithm =
       dot.precision_config().algorithm();
   if (algorithm == PrecisionConfig::ALG_DOT_BF16_BF16_F32_X6 ||
@@ -744,8 +754,7 @@
       algorithm == PrecisionConfig::ALG_DOT_TF32_TF32_F32 ||
       algorithm == PrecisionConfig::ALG_DOT_TF32_TF32_F32_X3 ||
       algorithm == PrecisionConfig::ALG_DOT_F32_F32_F32 ||
-      dot.GetModule()->config().debug_options().xla_gpu_triton_gemm_any() ||
-      dot.sparse_operands()) {
+      should_use_triton_gemm_any || dot.sparse_operands()) {
     return Decision::Allow();
   }
 
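The gist of the gate above is "honor --xla_gpu_triton_gemm_any only on Hopper (compute capability 9.0) or newer". Below is a minimal standalone sketch of that idea, not the XLA implementation: ComputeCapability and ShouldUseTritonGemmAny are illustrative stand-ins for se::CudaComputeCapability and the logic added to CreateDotFusion.

// Minimal sketch (illustrative only): gate the "fuse any GEMM into Triton"
// flag on the CUDA compute capability. Hopper is SM 9.0; Ampere (SM 8.x)
// keeps the flag off because of the broken f16 -> f8e5m2 conversion.
#include <cstdio>

struct ComputeCapability {
  int major = 0;
  int minor = 0;
  bool IsAtLeastHopper() const { return major >= 9; }
};

bool ShouldUseTritonGemmAny(bool flag_enabled, const ComputeCapability& cc) {
  // The flag only takes effect on Hopper or newer GPUs.
  return flag_enabled && cc.IsAtLeastHopper();
}

int main() {
  const ComputeCapability ampere{8, 0};  // e.g. A100
  const ComputeCapability hopper{9, 0};  // e.g. H100
  std::printf("Ampere: %d\n",
              ShouldUseTritonGemmAny(/*flag_enabled=*/true, ampere));  // 0
  std::printf("Hopper: %d\n",
              ShouldUseTritonGemmAny(/*flag_enabled=*/true, hopper));  // 1
  return 0;
}

In the real pass the flag value comes from the module's debug options and the capability from the gpu_version variant, as in the diff above; the sketch only mirrors the boolean gating.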
