Skip to content

Commit

Permalink
PR #22575: [XLA:GPU] Fix triton sparse dot lowering on Blackwell
Browse files Browse the repository at this point in the history
Imported from GitHub PR #22575

Sparse dot is supported for MMA v2 and v3 only, and sm100/sm120 should use MMA v2 (v3 is Hopper-only).
Copybara import of the project:

--
bd4c827 by Sergey Kozub <[email protected]>:

[XLA:GPU] Fix triton sparse dot lowering on Blackwell

Merging this change closes #22575

FUTURE_COPYBARA_INTEGRATE_REVIEW=#22575 from openxla:skozub/sm100_sparse bd4c827
PiperOrigin-RevId: 725966651
  • Loading branch information
sergey-kozub authored and Google-ML-Automation committed Feb 14, 2025
1 parent a16d96f commit f696311
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 36 deletions.
18 changes: 10 additions & 8 deletions xla/backends/gpu/codegen/triton/transforms/sparse_passes.cc
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,9 @@ class SparseBlockedToMMA : public RewritePattern {
assert(compute_capability_ >= 80 &&
"SparseDot is only supported on Ampere or higher");
bool allow_v3 = !triton::tools::getBoolEnv("DISABLE_MMA_V3");
int version_major = compute_capability_ >= 90 && allow_v3 ? 3 : 2;
// Sparse dot is supported for MMA v2 and v3 only, and sm100/sm120 should
// use MMA v2 (v3 is Hopper-only).
int triton_mma_version = compute_capability_ == 90 && allow_v3 ? 3 : 2;

// get MMA encoding and new return type given the number of warps
auto ret_shape_per_cta = triton::gpu::getShapePerCTA(ret_type);
Expand All @@ -282,13 +284,13 @@ class SparseBlockedToMMA : public RewritePattern {
auto cta_layout = triton::gpu::getCTALayout(ret_type.getEncoding());

auto instr_shape =
mmaVersionToInstrShape(version_major, ret_shape_per_cta,
mmaVersionToInstrShape(triton_mma_version, ret_shape_per_cta,
getElementTypeOrSelf(a.getType()), num_warps);
auto warps_per_tile = mlir::triton::gpu::getWarpsPerTile(
dot_op, ret_shape_per_cta, version_major, num_warps, instr_shape);
NvidiaMmaEncodingAttr mma_enc =
NvidiaMmaEncodingAttr::get(context, version_major, /*versionMinor=*/0,
warps_per_tile, cta_layout, instr_shape);
dot_op, ret_shape_per_cta, triton_mma_version, num_warps, instr_shape);
NvidiaMmaEncodingAttr mma_enc = NvidiaMmaEncodingAttr::get(
context, triton_mma_version, /*versionMinor=*/0, warps_per_tile,
cta_layout, instr_shape);
auto new_ret_type = RankedTensorType::get(
ret_type.getShape(), ret_type.getElementType(), mma_enc);

Expand All @@ -297,7 +299,7 @@ class SparseBlockedToMMA : public RewritePattern {
auto new_acc =
rewriter.create<ConvertLayoutOp>(acc.getLoc(), new_ret_type, acc);

if (version_major == 2) { // MMAV2
if (triton_mma_version == 2) { // MMAV2
int min_bit_width = std::min(triton::gpu::computeOrigBitWidth(a),
triton::gpu::computeOrigBitWidth(b));
int k_width = 32 / min_bit_width;
Expand All @@ -319,7 +321,7 @@ class SparseBlockedToMMA : public RewritePattern {
b = rewriter.create<ConvertLayoutOp>(b.getLoc(), b_type, b);

} else { // MMAV3
assert(version_major == 3 &&
assert(triton_mma_version == 3 &&
"Sparsity is only supported with MMAV2 or higher");
auto elt_type = dot_op.getA().getType().getElementType();
// In MMAV3 transpose is only supported for f16 and bf16.
Expand Down
50 changes: 22 additions & 28 deletions xla/hlo/transforms/collectives/all_reduce_combiner_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@ limitations under the License.

#include "xla/hlo/transforms/collectives/all_reduce_combiner.h"

#include <cstdint>
#include <memory>
#include <optional>
#include <vector>

#include <gmock/gmock.h>
Expand Down Expand Up @@ -118,9 +120,8 @@ TEST_F(AllReduceCombinerTest, CombineAllReduces) {
// Run the AllReduce combiner optimization pass.
AllReduceCombiner combine(10 * 1024 * 1024, kMaxCombineCount);
ASSERT_EQ(AllReduceCount(*module), inputs.size());
TF_ASSERT_OK_AND_ASSIGN(bool changed, combine.Run(module.get()));
EXPECT_THAT(combine.Run(module.get()), IsOkAndHolds(true));
ASSERT_EQ(AllReduceCount(*module), 1);
EXPECT_TRUE(changed);

ASSERT_EQ(root, computation->root_instruction());
ASSERT_EQ(inputs.size(), root->operands().size());
Expand Down Expand Up @@ -166,10 +167,9 @@ TEST_F(AllReduceCombinerTest, CombineCrossReplicaReductionsInGroups) {
// Run the AllReduce combiner optimization pass.
AllReduceCombiner combine(10 * 1024 * 1024, kMaxCombineCount);
ASSERT_EQ(AllReduceCount(*module), inputs.size());
TF_ASSERT_OK_AND_ASSIGN(bool changed, combine.Run(module.get()));
EXPECT_THAT(combine.Run(module.get()), IsOkAndHolds(true));
ASSERT_EQ(AllReduceCount(*module), 3)
<< "expects 3 groups for 3 reduction types.";
EXPECT_TRUE(changed);
}

// Tests that the combination threshold is respected.
Expand All @@ -188,19 +188,17 @@ TEST_F(AllReduceCombinerTest, RespectThreshold) {
{
AllReduceCombiner combine((8 + 4) * 1024 - 1, kMaxCombineCount);
ASSERT_EQ(AllReduceCount(*module), inputs.size());
TF_ASSERT_OK_AND_ASSIGN(bool changed, combine.Run(module.get()));
EXPECT_THAT(combine.Run(module.get()), IsOkAndHolds(false));
EXPECT_EQ(AllReduceCount(*module), inputs.size());
EXPECT_FALSE(changed);
}

// Run the AllReduce combiner optimization pass again with a slightly
// higher threshold so that the combination can occur.
{
AllReduceCombiner combine((8 + 4) * 1024, kMaxCombineCount);
ASSERT_EQ(AllReduceCount(*module), inputs.size());
TF_ASSERT_OK_AND_ASSIGN(bool changed, combine.Run(module.get()));
EXPECT_THAT(combine.Run(module.get()), IsOkAndHolds(true));
EXPECT_EQ(AllReduceCount(*module), 1);
EXPECT_TRUE(changed);
}
}

Expand All @@ -226,9 +224,8 @@ TEST_F(AllReduceCombinerTest, NoDependentCombination) {

AllReduceCombiner combine(1024 * 1024, kMaxCombineCount);
ASSERT_EQ(AllReduceCount(*module), 2);
TF_ASSERT_OK_AND_ASSIGN(bool changed, combine.Run(module.get()));
EXPECT_THAT(combine.Run(module.get()), IsOkAndHolds(false));
EXPECT_EQ(AllReduceCount(*module), 2);
EXPECT_FALSE(changed);
}

// Tests that AllReduce ops with different groups are not combined.
Expand All @@ -255,9 +252,8 @@ TEST_F(AllReduceCombinerTest, GroupAllReduce) {

AllReduceCombiner combine(1024 * 1024, kMaxCombineCount);
ASSERT_EQ(AllReduceCount(*module), 2);
TF_ASSERT_OK_AND_ASSIGN(bool changed, combine.Run(module.get()));
EXPECT_THAT(combine.Run(module.get()), IsOkAndHolds(false));
EXPECT_EQ(AllReduceCount(*module), 2);
EXPECT_FALSE(changed);
}

TEST_F(AllReduceCombinerTest, DomainPreventsCombining) {
Expand All @@ -278,9 +274,11 @@ ENTRY entry {
crs1 = f32[128] all-reduce(param1),
replica_groups={}, to_apply=summit, sharding={maximal device=1}
domain0 = f32[128] domain(crs0),
domain={kind="sharding", entry={{maximal device=0}, {maximal device=1}}, exit={maximal device=0}}
domain={kind="sharding", entry={{maximal device=0}, {maximal device=1}},
exit={maximal device=0}}
domain1 = f32[128] domain(crs1),
domain={kind="sharding", entry={{maximal device=0}, {maximal device=1}}, exit={maximal device=1}}
domain={kind="sharding", entry={{maximal device=0}, {maximal device=1}},
exit={maximal device=1}}
ROOT tuple = (f32[128], f32[128]) tuple(domain0, domain1),
sharding={{maximal device=0}, {maximal device=1}}
}
Expand All @@ -291,9 +289,8 @@ ENTRY entry {

AllReduceCombiner combine(1024 * 1024, kMaxCombineCount);
ASSERT_EQ(AllReduceCount(*module), 2);
TF_ASSERT_OK_AND_ASSIGN(bool changed, combine.Run(module.get()));
EXPECT_THAT(combine.Run(module.get()), IsOkAndHolds(false));
EXPECT_EQ(AllReduceCount(*module), 2);
EXPECT_FALSE(changed);
}

// This test checks that two CRS instructions that are in separate domains
Expand Down Expand Up @@ -336,9 +333,8 @@ ENTRY entry {

AllReduceCombiner combine(1024 * 1024, kMaxCombineCount);
ASSERT_EQ(AllReduceCount(*module), 3);
TF_ASSERT_OK_AND_ASSIGN(bool changed, combine.Run(module.get()));
EXPECT_THAT(combine.Run(module.get()), IsOkAndHolds(true));
EXPECT_EQ(AllReduceCount(*module), 2);
EXPECT_TRUE(changed);

// Verify that the sharding is combined correctly.
const HloInstruction* param0 =
Expand Down Expand Up @@ -375,9 +371,8 @@ ENTRY entry {

AllReduceCombiner combine(1024 * 1024, kMaxCombineCount);
ASSERT_EQ(AllReduceCount(*module), 2);
TF_ASSERT_OK_AND_ASSIGN(bool changed, combine.Run(module.get()));
EXPECT_THAT(combine.Run(module.get()), IsOkAndHolds(false));
EXPECT_EQ(AllReduceCount(*module), 2);
EXPECT_FALSE(changed);
}

TEST_F(AllReduceCombinerTest, DoNotCombineWithControlDependencies) {
Expand Down Expand Up @@ -453,9 +448,8 @@ ENTRY entry {

AllReduceCombiner combine(1024 * 1024, kMaxCombineCount);
ASSERT_EQ(AllReduceCount(*module), 4);
TF_ASSERT_OK_AND_ASSIGN(bool changed, combine.Run(module.get()));
EXPECT_THAT(combine.Run(module.get()), IsOkAndHolds(true));
EXPECT_EQ(AllReduceCount(*module), 2);
EXPECT_TRUE(changed);

EXPECT_THAT(module->entry_computation()->root_instruction(),
op::Add(op::Domain(op::GetTupleElement(AllOf(
Expand Down Expand Up @@ -501,9 +495,8 @@ ENTRY %comp {

AllReduceCombiner combine(1024 * 1024, kMaxCombineCount);
ASSERT_EQ(AllReduceCount(*module), 6);
TF_ASSERT_OK_AND_ASSIGN(bool changed, combine.Run(module.get()));
EXPECT_THAT(combine.Run(module.get()), IsOkAndHolds(true));
EXPECT_EQ(AllReduceCount(*module), 4);
EXPECT_TRUE(changed);

auto crs0 = op::AllReduce(op::Parameter(0), op::AllReduce(op::Parameter(1)));
auto add = op::Add(op::AllReduce(op::GetTupleElement(crs0, 0)),
Expand All @@ -527,16 +520,17 @@ TEST_F(AllReduceCombinerTest, PreservesMetadata) {
ENTRY entry {
%param.0 = f32[32] parameter(0)
%param.1 = f32[32] parameter(1)
%all-reduce.0 = f32[32] all-reduce(%param.0), replica_groups={}, to_apply=%add, metadata={op_type="test_type0" op_name="test_name0"}
%all-reduce.1 = f32[32] all-reduce(%param.1), replica_groups={}, to_apply=%add, metadata={op_type="test_type1" op_name="test_name1"}
%all-reduce.0 = f32[32] all-reduce(%param.0), replica_groups={},
to_apply=%add, metadata={op_type="test_type0" op_name="test_name0"}
%all-reduce.1 = f32[32] all-reduce(%param.1), replica_groups={},
to_apply=%add, metadata={op_type="test_type1" op_name="test_name1"}
ROOT tuple = (f32[32], f32[32]) tuple(%all-reduce.0, %all-reduce.1)
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnVerifiedModule(hlo_text));
AllReduceCombiner combine(1024 * 1024, kMaxCombineCount);
TF_ASSERT_OK_AND_ASSIGN(bool changed, combine.Run(module.get()));
EXPECT_TRUE(changed);
EXPECT_THAT(combine.Run(module.get()), IsOkAndHolds(true));
OpMetadata metadata;
metadata.set_op_type("test_type0");
metadata.set_op_name("test_name0");
Expand Down

0 comments on commit f696311

Please sign in to comment.