PR #22575: [XLA:GPU] Fix triton sparse dot lowering on Blackwell
Imported from GitHub PR #22575

Sparse dot is supported for MMA v2 and v3 only, and sm100/sm120 should use MMA v2 (v3 is Hopper-only).
Copybara import of the project:

--
bd4c827 by Sergey Kozub <[email protected]>:

[XLA:GPU] Fix triton sparse dot lowering on Blackwell

Merging this change closes #22575

FUTURE_COPYBARA_INTEGRATE_REVIEW=#22575 from openxla:skozub/sm100_sparse bd4c827
PiperOrigin-RevId: 725966651
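
For reference, a minimal sketch (in Python, for illustration only; the helper name is hypothetical) of the version-selection rule this change implements, mirroring the C++ diff below:

    def select_mma_version(compute_capability: int, allow_v3: bool) -> int:
        # Sparse dot only has MMA v2 and v3 lowerings; v3 is Hopper-only (sm90),
        # so Ampere (sm80/sm86) and Blackwell (sm100/sm120) fall back to v2.
        assert compute_capability >= 80, "SparseDot requires Ampere or newer"
        return 3 if compute_capability == 90 and allow_v3 else 2

    select_mma_version(90, True)   # 3: Hopper uses MMA v3
    select_mma_version(100, True)  # 2: Blackwell sm100 uses MMA v2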
sergey-kozub authored and Google-ML-Automation committed Feb 14, 2025
1 parent 8ca669b commit e999cb1
Showing 7 changed files with 184 additions and 59 deletions.
18 changes: 10 additions & 8 deletions xla/backends/gpu/codegen/triton/transforms/sparse_passes.cc
@@ -273,7 +273,9 @@ class SparseBlockedToMMA : public RewritePattern {
assert(compute_capability_ >= 80 &&
"SparseDot is only supported on Ampere or higher");
bool allow_v3 = !triton::tools::getBoolEnv("DISABLE_MMA_V3");
int version_major = compute_capability_ >= 90 && allow_v3 ? 3 : 2;
// Sparse dot is supported for MMA v2 and v3 only, and sm100/sm120 should
// use MMA v2 (v3 is Hopper-only).
int triton_mma_version = compute_capability_ == 90 && allow_v3 ? 3 : 2;

// get MMA encoding and new return type given the number of warps
auto ret_shape_per_cta = triton::gpu::getShapePerCTA(ret_type);
@@ -282,13 +284,13 @@
auto cta_layout = triton::gpu::getCTALayout(ret_type.getEncoding());

auto instr_shape =
mmaVersionToInstrShape(version_major, ret_shape_per_cta,
mmaVersionToInstrShape(triton_mma_version, ret_shape_per_cta,
getElementTypeOrSelf(a.getType()), num_warps);
auto warps_per_tile = mlir::triton::gpu::getWarpsPerTile(
dot_op, ret_shape_per_cta, version_major, num_warps, instr_shape);
NvidiaMmaEncodingAttr mma_enc =
NvidiaMmaEncodingAttr::get(context, version_major, /*versionMinor=*/0,
warps_per_tile, cta_layout, instr_shape);
dot_op, ret_shape_per_cta, triton_mma_version, num_warps, instr_shape);
NvidiaMmaEncodingAttr mma_enc = NvidiaMmaEncodingAttr::get(
context, triton_mma_version, /*versionMinor=*/0, warps_per_tile,
cta_layout, instr_shape);
auto new_ret_type = RankedTensorType::get(
ret_type.getShape(), ret_type.getElementType(), mma_enc);

@@ -297,7 +299,7 @@
auto new_acc =
rewriter.create<ConvertLayoutOp>(acc.getLoc(), new_ret_type, acc);

if (version_major == 2) { // MMAV2
if (triton_mma_version == 2) { // MMAV2
int min_bit_width = std::min(triton::gpu::computeOrigBitWidth(a),
triton::gpu::computeOrigBitWidth(b));
int k_width = 32 / min_bit_width;
@@ -319,7 +321,7 @@
b = rewriter.create<ConvertLayoutOp>(b.getLoc(), b_type, b);

} else { // MMAV3
assert(version_major == 3 &&
assert(triton_mma_version == 3 &&
"Sparsity is only supported with MMAV2 or higher");
auto elt_type = dot_op.getA().getType().getElementType();
// In MMAV3 transpose is only supported for f16 and bf16.
14 changes: 12 additions & 2 deletions xla/tools/BUILD
@@ -19,7 +19,7 @@ load(
"if_google",
"tsl_gpu_library",
)
load("//xla/tsl:tsl.default.bzl", "filegroup")
load("//xla/tsl:tsl.default.bzl", "filegroup", "tsl_pybind_extension")
load(
"//xla/tsl/platform:build_config.bzl",
"tf_proto_library",
@@ -716,6 +716,7 @@ cc_library(
"//xla/pjrt:pjrt_client",
"//xla/pjrt/plugin/xla_gpu:xla_gpu_client_options",
"//xla/service:backend",
"//xla/service:hlo_module_config",
"//xla/service/gpu/model:hlo_op_profile_proto_cc",
"//xla/service/gpu/model:hlo_op_profiler_lib",
"//xla/service/gpu/model:hlo_op_profiles",
@@ -759,14 +760,23 @@ cc_library(
alwayslink = True,
)

tsl_pybind_extension(
name = "collective_perf_table_gen_bindings",
srcs = ["collective_perf_table_gen_bindings.cc"],
deps = [
":collective_perf_table_gen",
"@com_google_absl//absl/log:check",
"@nanobind",
],
)

xla_test(
name = "collective_perf_table_gen_test",
srcs = ["collective_perf_table_gen_test.cc"],
backends = ["gpu"],
local_defines = if_cuda(["GOOGLE_CUDA"]),
deps = [
":collective_perf_table_gen",
"//xla/hlo/ir:hlo",
"//xla/service/gpu/model:hlo_op_profile_proto_cc",
"//xla/stream_executor/cuda:cuda_compute_capability",
"//xla/tests:hlo_test_base",
53 changes: 43 additions & 10 deletions xla/tools/collective_perf_table_gen.cc
@@ -26,6 +26,7 @@ limitations under the License.
#include "absl/status/status.h"
#include "absl/strings/match.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "absl/strings/substitute.h"
#include "absl/time/time.h"
#include "xla/hlo/ir/collective_device_list.h"
@@ -38,6 +39,7 @@ limitations under the License.
#include "xla/service/gpu/model/hlo_op_profile.pb.h"
#include "xla/service/gpu/model/hlo_op_profiler.h"
#include "xla/service/gpu/model/hlo_op_profiles.h"
#include "xla/service/hlo_module_config.h"
#include "xla/tools/multihost_hlo_runner/create_client.h"
#include "xla/tools/multihost_hlo_runner/functional_hlo_runner.h"
#include "xla/tsl/platform/env.h"
@@ -186,7 +188,10 @@ std::unique_ptr<HloModule> CreateCollectiveModule(const StaticSpec& spec) {
std::string hlo =
GetHlo(spec.collective_type, input_dim, output_dim, spec.replica_groups);

auto parsed = ParseAndReturnUnverifiedModule(hlo);
HloModuleConfig config;
config.set_num_partitions(spec.replica_groups.num_devices_per_group() *
spec.replica_groups.num_replica_groups());
auto parsed = ParseAndReturnUnverifiedModule(hlo, config);
CHECK_OK(parsed.status());
return std::move(*parsed);
}
@@ -202,6 +207,16 @@ uint64_t GetNetworkThroughputBytesPerSec(absl::Duration runtime,
return tensor_size_bytes * 1e9 / absl::ToInt64Nanoseconds(runtime);
}

IotaReplicaGroupList GetCollectiveDeviceList(
absl::string_view collective_device_list_unparsed) {
auto collective_device_list =
xla::ParseCollectiveDeviceListOnly(collective_device_list_unparsed);
CHECK_OK(collective_device_list);
CHECK(collective_device_list->iota_replica_group_list().has_value());

return *collective_device_list->iota_replica_group_list();
}

} // namespace

/*static*/ std::unique_ptr<CollectivePerfTableGen>
@@ -224,7 +239,7 @@ std::unique_ptr<PjRtLoadedExecutable> CollectivePerfTableGen::Compile(
std::unique_ptr<HloModule> module) {
DebugOptions debug_opts;
FunctionalHloRunner::RawCompileOptions opts;
opts.num_partitions = 8;
opts.num_partitions = module->config().num_partitions();
opts.spmd_mode = FunctionalHloRunner::SpmdMode::kUseSpmdPartitioning;
auto compile_opts = FunctionalHloRunner::CreateCompileOptions(
*pjrt_env_.client, opts, config_.task_id, config_.num_nodes);
@@ -250,18 +265,26 @@ CollectivePerfTableGen::ProfilingData CollectivePerfTableGen::Profile(
VLOG(1) << "Compiled module: "
<< executable->GetHloModules().value()[0]->ToString();

// We do not profile dry runs or on more than one task.
if (config_.dry_run) {
return {};
}

std::unique_ptr<HloOpProfiler::KernelTracer> tracer =
HloOpProfiler::GetKernelTracer();
if (config_.task_id == 0) {
std::unique_ptr<HloOpProfiler::KernelTracer> tracer =
HloOpProfiler::GetKernelTracer();
for (int i = 0; i < kNumProfilingRuns; ++i) {
Run(*executable);
}
return {
/*runtime=*/absl::Nanoseconds(
std::move(*tracer).getMedianKernelTimeNs()),
};
}
for (int i = 0; i < kNumProfilingRuns; ++i) {
Run(*executable);
}
return {
/*runtime=*/absl::Nanoseconds(std::move(*tracer).getMedianKernelTimeNs()),
};
return {};
}

DeviceHloInstructionProfiles CollectivePerfTableGen::ComputeTable() {
@@ -282,11 +305,14 @@
for (int64_t tensor_size = tsize_spec.start; tensor_size <= tsize_spec.stop;
tensor_size = inc(tensor_size, tsize_spec)) {
for (CollectiveType collective_type : config_.collective_types) {
for (const IotaReplicaGroupList& replica_groups :
config_.replica_groups_list) {
for (absl::string_view replica_groups_raw : config_.replica_groups_list) {
CHECK(collective_type != CollectiveType::UNSPECIFIED);

StaticSpec spec{collective_type, replica_groups, tensor_size};
StaticSpec spec{
collective_type,
GetCollectiveDeviceList(replica_groups_raw),
tensor_size,
};
static_specs.push_back(spec);
}
}
@@ -313,6 +339,10 @@ DeviceHloInstructionProfiles CollectivePerfTableGen::ComputeTable() {
if (!config_.dry_run) {
profiled_data = Profile(std::move(spec.module));
}
if (profiled_data.runtime == absl::ZeroDuration()) {
VLOG(1) << "Size: " << static_spec.tensor_size_bytes << " too small.";
continue;
}
entry.set_network_throughput_bytes_per_sec(GetNetworkThroughputBytesPerSec(
profiled_data.runtime, static_spec.tensor_size_bytes));

@@ -332,6 +362,9 @@ DeviceHloInstructionProfiles CollectivePerfTableGen::ComputeTable() {

absl::Status CollectivePerfTableGen::Dump(
const DeviceHloInstructionProfiles& table) {
if (config_.task_id != 0) {
return absl::OkStatus();
}
if (config_.output == CollectivePerfTableGen::Config::kStdout) {
LOG(INFO) << table.DebugString();
return absl::OkStatus();
17 changes: 10 additions & 7 deletions xla/tools/collective_perf_table_gen.h
@@ -26,7 +26,6 @@ limitations under the License.
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "absl/time/time.h"
#include "xla/hlo/ir/collective_device_list.h"
#include "xla/hlo/ir/hlo_module.h"
#include "xla/pjrt/pjrt_client.h"
#include "xla/service/backend.h"
@@ -43,10 +42,10 @@ namespace xla::gpu {
class CollectivePerfTableGen {
public:
struct StepSpec {
int64_t start = 0;
int64_t stop = -1;
int64_t start = 1024;
int64_t stop = 2ll * 1024 * 1024 * 1024;
int64_t step = 0;
int64_t factor = 0;
int64_t factor = 2;
};

enum class CollectiveType {
@@ -61,14 +60,18 @@

// Search space.
StepSpec tensor_size_bytes_spec;
std::vector<CollectiveType> collective_types;
std::vector<IotaReplicaGroupList> replica_groups_list;
std::vector<CollectiveType> collective_types = {
CollectiveType::ALL_REDUCE,
CollectiveType::ALL_GATHER,
CollectiveType::REDUCE_SCATTER,
};
std::vector<std::string> replica_groups_list;

// Execution opts.
bool dry_run = false;
std::string output = std::string(kStdout);
std::string coordinator_address = "";
absl::Duration connection_timeout = absl::Seconds(60);
absl::Duration connection_timeout = absl::Seconds(600);
uint16_t num_nodes = 1;
uint16_t task_id = 0;
};
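
For intuition, the new StepSpec defaults describe a geometric sweep over tensor sizes. A hedged Python sketch of the sizes they generate, assuming the generator multiplies by factor whenever step is 0 (as the inc() call in ComputeTable suggests; names here are illustrative only):

    def tensor_sizes(start=1024, stop=2 * 1024**3, step=0, factor=2):
        # With the new defaults: 1 KiB, 2 KiB, 4 KiB, ..., up to 2 GiB.
        size = start
        while size <= stop:
            yield size
            size = size + step if step else size * factor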
74 changes: 74 additions & 0 deletions xla/tools/collective_perf_table_gen_bindings.cc
@@ -0,0 +1,74 @@
/* Copyright 2025 The OpenXLA Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <memory>

#include "absl/log/check.h"
#include "nanobind/nanobind.h"
#include "nanobind/nb_defs.h"
#include "nanobind/stl/string.h" // IWYU pragma: keep
#include "nanobind/stl/unique_ptr.h" // IWYU pragma: keep
#include "nanobind/stl/vector.h" // IWYU pragma: keep
#include "xla/tools/collective_perf_table_gen.h"

namespace nb = nanobind;

NB_MODULE(collective_perf_table_gen_bindings, m) {
// Bind the Config struct
nb::class_<xla::gpu::CollectivePerfTableGen::Config>(m, "Config")
.def(nb::init<>())
.def_rw("tensor_size_bytes_spec",
&xla::gpu::CollectivePerfTableGen::Config::tensor_size_bytes_spec)
.def_rw("collective_types",
&xla::gpu::CollectivePerfTableGen::Config::collective_types)
.def_rw("replica_groups_list",
&xla::gpu::CollectivePerfTableGen::Config::replica_groups_list)
.def_rw("dry_run", &xla::gpu::CollectivePerfTableGen::Config::dry_run)
.def_rw("output", &xla::gpu::CollectivePerfTableGen::Config::output)
.def_rw("coordinator_address",
&xla::gpu::CollectivePerfTableGen::Config::coordinator_address)
.def_rw("connection_timeout",
&xla::gpu::CollectivePerfTableGen::Config::connection_timeout)
.def_rw("num_nodes", &xla::gpu::CollectivePerfTableGen::Config::num_nodes)
.def_rw("task_id", &xla::gpu::CollectivePerfTableGen::Config::task_id);

// Bind the StepSpec struct
nb::class_<xla::gpu::CollectivePerfTableGen::StepSpec>(m, "StepSpec")
.def(nb::init<>())
.def_rw("start", &xla::gpu::CollectivePerfTableGen::StepSpec::start)
.def_rw("stop", &xla::gpu::CollectivePerfTableGen::StepSpec::stop)
.def_rw("step", &xla::gpu::CollectivePerfTableGen::StepSpec::step)
.def_rw("factor", &xla::gpu::CollectivePerfTableGen::StepSpec::factor);

// Bind the CollectiveType enum
nb::enum_<xla::gpu::CollectivePerfTableGen::CollectiveType>(m,
"CollectiveType")
.value("UNSPECIFIED",
xla::gpu::CollectivePerfTableGen::CollectiveType::UNSPECIFIED)
.value("ALL_REDUCE",
xla::gpu::CollectivePerfTableGen::CollectiveType::ALL_REDUCE)
.value("REDUCE_SCATTER",
xla::gpu::CollectivePerfTableGen::CollectiveType::REDUCE_SCATTER)
.value("ALL_GATHER",
xla::gpu::CollectivePerfTableGen::CollectiveType::ALL_GATHER)
.export_values();

m.def("run", [](xla::gpu::CollectivePerfTableGen::Config config) -> void {
std::unique_ptr<xla::gpu::CollectivePerfTableGen> gen =
xla::gpu::CollectivePerfTableGen::Create(config);
auto table = gen->ComputeTable();
CHECK_OK(gen->Dump(table));
});
}
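
As a usage illustration only: the module name comes from NB_MODULE above and the attribute/enum names from the def_rw/value bindings, but the Python import path, build wiring, and the replica-group string syntax are assumptions:

    import collective_perf_table_gen_bindings as bindings

    config = bindings.Config()
    config.collective_types = [
        bindings.CollectiveType.ALL_REDUCE,
        bindings.CollectiveType.ALL_GATHER,
    ]
    # One string per replica-group configuration to sweep; parsed on the C++ side
    # by ParseCollectiveDeviceListOnly (iota form assumed).
    config.replica_groups_list = ["[2,4]<=[8]"]
    config.tensor_size_bytes_spec = bindings.StepSpec()  # defaults: 1 KiB .. 2 GiB, x2

    bindings.run(config)  # computes the perf table and dumps it (stdout by default)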