Skip to content

Commit

Permalink
PR #22575: [XLA:GPU] Fix triton sparse dot lowering on Blackwell
Browse files Browse the repository at this point in the history
Imported from GitHub PR #22575

Sparse dot is supported for MMA v2 and v3 only, and sm100/sm120 should use MMA v2 (v3 is Hopper-only).
Copybara import of the project:

--
bd4c827 by Sergey Kozub <[email protected]>:

[XLA:GPU] Fix triton sparse dot lowering on Blackwell

Merging this change closes #22575

FUTURE_COPYBARA_INTEGRATE_REVIEW=#22575 from openxla:skozub/sm100_sparse bd4c827
PiperOrigin-RevId: 725966651
  • Loading branch information
sergey-kozub authored and Google-ML-Automation committed Feb 14, 2025
1 parent a16d96f commit aa8a416
Showing 1 changed file with 10 additions and 8 deletions.
18 changes: 10 additions & 8 deletions xla/backends/gpu/codegen/triton/transforms/sparse_passes.cc
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,9 @@ class SparseBlockedToMMA : public RewritePattern {
assert(compute_capability_ >= 80 &&
"SparseDot is only supported on Ampere or higher");
bool allow_v3 = !triton::tools::getBoolEnv("DISABLE_MMA_V3");
int version_major = compute_capability_ >= 90 && allow_v3 ? 3 : 2;
// Sparse dot is supported for MMA v2 and v3 only, and sm100/sm120 should
// use MMA v2 (v3 is Hopper-only).
int triton_mma_version = compute_capability_ == 90 && allow_v3 ? 3 : 2;

// get MMA encoding and new return type given the number of warps
auto ret_shape_per_cta = triton::gpu::getShapePerCTA(ret_type);
Expand All @@ -282,13 +284,13 @@ class SparseBlockedToMMA : public RewritePattern {
auto cta_layout = triton::gpu::getCTALayout(ret_type.getEncoding());

auto instr_shape =
mmaVersionToInstrShape(version_major, ret_shape_per_cta,
mmaVersionToInstrShape(triton_mma_version, ret_shape_per_cta,
getElementTypeOrSelf(a.getType()), num_warps);
auto warps_per_tile = mlir::triton::gpu::getWarpsPerTile(
dot_op, ret_shape_per_cta, version_major, num_warps, instr_shape);
NvidiaMmaEncodingAttr mma_enc =
NvidiaMmaEncodingAttr::get(context, version_major, /*versionMinor=*/0,
warps_per_tile, cta_layout, instr_shape);
dot_op, ret_shape_per_cta, triton_mma_version, num_warps, instr_shape);
NvidiaMmaEncodingAttr mma_enc = NvidiaMmaEncodingAttr::get(
context, triton_mma_version, /*versionMinor=*/0, warps_per_tile,
cta_layout, instr_shape);
auto new_ret_type = RankedTensorType::get(
ret_type.getShape(), ret_type.getElementType(), mma_enc);

Expand All @@ -297,7 +299,7 @@ class SparseBlockedToMMA : public RewritePattern {
auto new_acc =
rewriter.create<ConvertLayoutOp>(acc.getLoc(), new_ret_type, acc);

if (version_major == 2) { // MMAV2
if (triton_mma_version == 2) { // MMAV2
int min_bit_width = std::min(triton::gpu::computeOrigBitWidth(a),
triton::gpu::computeOrigBitWidth(b));
int k_width = 32 / min_bit_width;
Expand All @@ -319,7 +321,7 @@ class SparseBlockedToMMA : public RewritePattern {
b = rewriter.create<ConvertLayoutOp>(b.getLoc(), b_type, b);

} else { // MMAV3
assert(version_major == 3 &&
assert(triton_mma_version == 3 &&
"Sparsity is only supported with MMAV2 or higher");
auto elt_type = dot_op.getA().getType().getElementType();
// In MMAV3 transpose is only supported for f16 and bf16.
Expand Down

0 comments on commit aa8a416

Please sign in to comment.