From aa8a416f0d23d44ab8bb0fdebe545c4d3ef7f63e Mon Sep 17 00:00:00 2001
From: Sergey Kozub
Date: Wed, 12 Feb 2025 02:16:42 -0800
Subject: [PATCH] PR #22575: [XLA:GPU] Fix triton sparse dot lowering on
 Blackwell

Imported from GitHub PR https://github.com/openxla/xla/pull/22575

Sparse dot is supported for MMA v2 and v3 only, and sm100/sm120 should use
MMA v2 (v3 is Hopper-only).

Copybara import of the project:

--
bd4c827db0e4adbff629bf0b02d09ff2860e4fb2 by Sergey Kozub:

[XLA:GPU] Fix triton sparse dot lowering on Blackwell

Merging this change closes #22575

FUTURE_COPYBARA_INTEGRATE_REVIEW=https://github.com/openxla/xla/pull/22575 from openxla:skozub/sm100_sparse bd4c827db0e4adbff629bf0b02d09ff2860e4fb2
PiperOrigin-RevId: 725966651
---
 .../codegen/triton/transforms/sparse_passes.cc | 18 ++++++++++--------
 1 file changed, 10 insertions(+), 8 deletions(-)

diff --git a/xla/backends/gpu/codegen/triton/transforms/sparse_passes.cc b/xla/backends/gpu/codegen/triton/transforms/sparse_passes.cc
index 9a228742d0acd..e848b60af7cf6 100644
--- a/xla/backends/gpu/codegen/triton/transforms/sparse_passes.cc
+++ b/xla/backends/gpu/codegen/triton/transforms/sparse_passes.cc
@@ -273,7 +273,9 @@ class SparseBlockedToMMA : public RewritePattern {
     assert(compute_capability_ >= 80 &&
            "SparseDot is only supported on Ampere or higher");
     bool allow_v3 = !triton::tools::getBoolEnv("DISABLE_MMA_V3");
-    int version_major = compute_capability_ >= 90 && allow_v3 ? 3 : 2;
+    // Sparse dot is supported for MMA v2 and v3 only, and sm100/sm120 should
+    // use MMA v2 (v3 is Hopper-only).
+    int triton_mma_version = compute_capability_ == 90 && allow_v3 ? 3 : 2;
 
     // get MMA encoding and new return type given the number of warps
     auto ret_shape_per_cta = triton::gpu::getShapePerCTA(ret_type);
@@ -282,13 +284,13 @@
 
     auto cta_layout = triton::gpu::getCTALayout(ret_type.getEncoding());
     auto instr_shape =
-        mmaVersionToInstrShape(version_major, ret_shape_per_cta,
+        mmaVersionToInstrShape(triton_mma_version, ret_shape_per_cta,
                                getElementTypeOrSelf(a.getType()), num_warps);
     auto warps_per_tile = mlir::triton::gpu::getWarpsPerTile(
-        dot_op, ret_shape_per_cta, version_major, num_warps, instr_shape);
-    NvidiaMmaEncodingAttr mma_enc =
-        NvidiaMmaEncodingAttr::get(context, version_major, /*versionMinor=*/0,
-                                   warps_per_tile, cta_layout, instr_shape);
+        dot_op, ret_shape_per_cta, triton_mma_version, num_warps, instr_shape);
+    NvidiaMmaEncodingAttr mma_enc = NvidiaMmaEncodingAttr::get(
+        context, triton_mma_version, /*versionMinor=*/0, warps_per_tile,
+        cta_layout, instr_shape);
     auto new_ret_type = RankedTensorType::get(
         ret_type.getShape(), ret_type.getElementType(), mma_enc);
 
@@ -297,7 +299,7 @@
     auto new_acc =
        rewriter.create<ConvertLayoutOp>(acc.getLoc(), new_ret_type, acc);
 
-    if (version_major == 2) {  // MMAV2
+    if (triton_mma_version == 2) {  // MMAV2
       int min_bit_width = std::min(triton::gpu::computeOrigBitWidth(a),
                                    triton::gpu::computeOrigBitWidth(b));
       int k_width = 32 / min_bit_width;
@@ -319,7 +321,7 @@
 
       b = rewriter.create<ConvertLayoutOp>(b.getLoc(), b_type, b);
     } else {  // MMAV3
-      assert(version_major == 3 &&
+      assert(triton_mma_version == 3 &&
              "Sparsity is only supported with MMAV2 or higher");
       auto elt_type = dot_op.getA().getType().getElementType();
       // In MMAV3 transpose is only supported for f16 and bf16.
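
Note (illustrative, not part of the patch): the essence of the fix is the MMA
version predicate. The sketch below restates it as a self-contained program,
assuming integer SM compute capabilities (80/86 Ampere, 90 Hopper, 100/120
Blackwell); the function name SparseDotMmaVersion is hypothetical, and the
DISABLE_MMA_V3 environment lookup is reduced to a bool parameter.

#include <cassert>

// Hypothetical restatement of the selection in SparseBlockedToMMA: sparse
// dot lowers only to MMA v2 or v3, and v3 is Hopper-only (sm90), so both
// Ampere (sm80/sm86) and Blackwell (sm100/sm120) must use v2.
int SparseDotMmaVersion(int compute_capability, bool allow_v3) {
  assert(compute_capability >= 80 &&
         "SparseDot is only supported on Ampere or higher");
  return (compute_capability == 90 && allow_v3) ? 3 : 2;
}

int main() {
  assert(SparseDotMmaVersion(80, /*allow_v3=*/true) == 2);   // Ampere
  assert(SparseDotMmaVersion(90, /*allow_v3=*/true) == 3);   // Hopper
  assert(SparseDotMmaVersion(90, /*allow_v3=*/false) == 2);  // v3 disabled
  // Before this fix, the ">= 90" comparison incorrectly selected v3 on
  // Blackwell:
  assert(SparseDotMmaVersion(100, /*allow_v3=*/true) == 2);
  assert(SparseDotMmaVersion(120, /*allow_v3=*/true) == 2);
  return 0;
}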